Unverified Commit f0416ebf authored by NingMa's avatar NingMa Committed by GitHub

Add files via upload

parent fec106d0
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader
from usps import *
from data_pro.office_home import ImageList
import random
class PlaceCrop(object):
    """Crop a PIL.Image at a fixed top-left corner.

    Args:
        size (sequence or int): desired output size; an int yields a
            square (size, size) crop.
        start_x (int): x coordinate of the crop's top-left corner.
        start_y (int): y coordinate of the crop's top-left corner.
    """
    def __init__(self, size, start_x, start_y):
        # Normalize an int size to a (w, h) tuple.
        self.size = (int(size), int(size)) if isinstance(size, int) else size
        self.start_x = start_x
        self.start_y = start_y

    def __call__(self, img):
        """Return the cropped PIL.Image."""
        th, tw = self.size
        box = (self.start_x, self.start_y,
               self.start_x + tw, self.start_y + th)
        return img.crop(box)
class ResizeImage():
    """Resize a PIL.Image to a fixed target size (int means square)."""
    def __init__(self, size):
        if isinstance(size, int):
            size = (int(size), int(size))
        self.size = size

    def __call__(self, img):
        first, second = self.size
        return img.resize((first, second))
def image_train(resize_size=256, crop_size=224):
    """Training transform: resize, random crop + horizontal flip, then
    normalization with ImageNet channel statistics.

    Args:
        resize_size (int): side length the image is first resized to.
        crop_size (int): side length of the random crop fed to the network.
    Returns:
        A torchvision ``transforms.Compose`` pipeline producing a tensor.
    """
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # `transforms` is already torchvision.transforms, so the fully
        # qualified torchvision.transforms.Normalize was redundant.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
def image_test(resize_size=256, crop_size=224):
    """Evaluation transform: resize, deterministic center crop, then
    normalization with ImageNet channel statistics.

    Args:
        resize_size (int): side length the image is first resized to.
        crop_size (int): side length of the center crop.
    Returns:
        A torchvision ``transforms.Compose`` pipeline producing a tensor.
    """
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        # Consistent with image_train: use the local `transforms` alias.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
def get_data_loader(args):
    """Build source/target train/test DataLoaders for args.dataset.

    Supports "digits" (delegated to digit_load), "office_home" and
    "office-31" (image-list files on disk).

    Returns:
        dict with DataLoaders under "source_tr", "source_te", "target",
        "test".
    """
    dataset = args.dataset
    source_path = None
    target_path = None
    train_bs = args.batch_size
    dset_loaders = {}

    def split(train_r, source_path):
        # Randomly split the list-file lines into train/val subsets.
        with open(source_path, 'r') as f:
            data = f.readlines()
        train_len = int(len(data) * train_r)
        train, val = torch.utils.data.random_split(data, [train_len, len(data) - train_len])
        return train, val

    if dataset == "digits":
        # BUG FIX: this previously called get_data_loader(args) again,
        # recursing forever; the digits pipeline lives in digit_load().
        return digit_load(args)
    elif dataset == "office_home":
        source_path = os.path.join("./data/office-home/", args.s + ".txt")
        target_path = os.path.join("./data/office-home/", args.t + ".txt")
    elif dataset == "office-31":
        source_path = os.path.join("./data/office/", args.s + ".txt")
        target_path = os.path.join("./data/office/", args.t + ".txt")

    train_data, val_data = split(train_r=0.9, source_path=source_path)
    train_source = ImageList(train_data, transform=image_train())
    test_source = ImageList(val_data, transform=image_test())
    # Read the target list once and reuse it; the old code leaked two
    # unclosed file handles here.
    with open(target_path) as f:
        target_lines = f.readlines()
    train_target = ImageList(target_lines, transform=image_train())
    train_target.return_index = 1
    test_target = ImageList(target_lines, transform=image_test())
    test_target.return_index = 1
    dset_loaders["source_tr"] = DataLoader(train_source, batch_size=train_bs, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["source_te"] = DataLoader(test_source, batch_size=train_bs * 2, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["target"] = DataLoader(train_target, batch_size=train_bs, shuffle=True,
                                        num_workers=args.worker, drop_last=False)
    dset_loaders["test"] = DataLoader(test_target, batch_size=train_bs * 2, shuffle=False,
                                      num_workers=args.worker, drop_last=False)
    return dset_loaders
def digit_load(args):
    """Build train/test DataLoaders for one digits adaptation task.

    The task string is ``args.s + "2" + args.t``; supported values are
    's2m' (SVHN->MNIST), 'u2m' (USPS->MNIST), 'm2u' (MNIST->USPS) and
    'm2s' (MNIST->SVHN).

    Returns:
        dict with DataLoaders under "source_tr", "source_te", "target",
        "test".
    Raises:
        ValueError: for an unsupported source/target pair (previously the
            code fell through and crashed later with NameError on the
            undefined loader variables).
    """
    train_bs = args.batch_size
    task = args.s + "2" + args.t
    if task == 's2m':
        # SVHN (RGB 32x32) source; MNIST converted to RGB at 32x32 target.
        train_source = torchvision.datasets.SVHN('../data/digit/svhn/', split='train', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        test_source = torchvision.datasets.SVHN('../data/digit/svhn/', split='test', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        train_target = torchvision.datasets.MNIST('../data/digit/mnist/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.Lambda(lambda x: x.convert("RGB")),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        test_target = torchvision.datasets.MNIST('../data/digit/mnist/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.Lambda(lambda x: x.convert("RGB")),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
    elif task == 'u2m':
        # USPS source with crop/rotation augmentation; grayscale MNIST target.
        train_source = USPS('../data/digit/usps/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(28, padding=4),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_source = USPS('../data/digit/usps/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.RandomCrop(28, padding=4),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        train_target = torchvision.datasets.MNIST('../data/digit/mnist/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_target = torchvision.datasets.MNIST('../data/digit/mnist/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
    elif task == 'm2u':
        # Grayscale MNIST source; USPS target.
        train_source = torchvision.datasets.MNIST('../data/digit/mnist/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_source = torchvision.datasets.MNIST('../data/digit/mnist/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        train_target = USPS('../data/digit/usps/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
        test_target = USPS('../data/digit/usps/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ]))
    elif task == 'm2s':
        # BUG FIX: this branch was a detached `if`, so any pair not in the
        # chain above silently fell through with loaders undefined.
        train_target = torchvision.datasets.SVHN('../data/digit/svhn/', split='train', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        test_target = torchvision.datasets.SVHN('../data/digit/svhn/', split='test', download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        train_source = torchvision.datasets.MNIST('../data/digit/mnist/', train=True, download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.Lambda(lambda x: x.convert("RGB")),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
        test_source = torchvision.datasets.MNIST('../data/digit/mnist/', train=False, download=True,
                transform=transforms.Compose([
                    transforms.Resize(32),
                    transforms.Lambda(lambda x: x.convert("RGB")),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]))
    else:
        raise ValueError("unsupported digits task: %s" % task)

    dset_loaders = {}
    dset_loaders["source_tr"] = DataLoader(train_source, batch_size=train_bs, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["source_te"] = DataLoader(test_source, batch_size=train_bs * 2, shuffle=True,
                                           num_workers=args.worker, drop_last=False)
    dset_loaders["target"] = DataLoader(train_target, batch_size=train_bs, shuffle=True,
                                        num_workers=args.worker, drop_last=False)
    dset_loaders["test"] = DataLoader(test_target, batch_size=train_bs * 2, shuffle=False,
                                      num_workers=args.worker, drop_last=False)
    return dset_loaders
# class ImageList_idx(Dataset):
# def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
# imgs = make_dataset(image_list, labels)
# if len(imgs) == 0:
# raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
# "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
#
# self.imgs = imgs
# self.transform = transform
# self.target_transform = target_transform
# if mode == 'RGB':
# self.loader = rgb_loader
# elif mode == 'L':
# self.loader = l_loader
#
# def __getitem__(self, index):
# path, target = self.imgs[index]
# img = self.loader(path)
# if self.transform is not None:
# img = self.transform(img)
# if self.target_transform is not None:
# target = self.target_transform(target)
#
# return img, target, index
#
# def __len__(self):
# return len(self.imgs)
\ No newline at end of file
import numpy as np
import os
import os.path
from PIL import Image
def pil_loader(path):
    """Load the image file at *path* and return it as an RGB PIL.Image."""
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def make_dataset_fromlist(image_list):
    """Parse "path label" lines into parallel numpy arrays.

    Args:
        image_list: iterable of lines, each "relative/path label".
    Returns:
        tuple ``(paths, labels)`` — np.ndarray of path strings and
        np.ndarray of int labels, in input order.
    """
    # The old implementation built a `selected_list` of *all* indices and
    # re-indexed with it — a no-op removed here; single pass instead.
    paths = []
    labels = []
    for line in image_list:
        tokens = line.split(' ')
        paths.append(tokens[0])
        # strip() drops the trailing newline before int conversion.
        labels.append(int(tokens[1].strip()))
    return np.array(paths), np.array(labels)
def return_classlist(image_list):
    """Return the unique class names found in a list file.

    The class name is the second-to-last path component of each line's
    image path ("root/<class>/<file> label"); order of first appearance
    is preserved.
    """
    classes = []
    with open(image_list) as handle:
        for line in handle:
            name = line.split(' ')[0].split('/')[-2]
            if name not in classes:
                classes.append(str(name))
    return classes
class Imagelists_VISDA(object):
    """Dataset over a "path label" image list rooted at *root*.

    Items are (img, target); with ``return_index`` set they become
    (img, target, index), and in test mode (img, target, relative_path).
    """
    def __init__(self, image_list, root="./data/multi/",
                 transform=None, target_transform=None, test=False):
        paths, targets = make_dataset_fromlist(image_list)
        self.imgs = paths
        self.labels = targets
        self.transform = transform
        self.target_transform = target_transform
        self.loader = pil_loader
        self.root = root
        self.test = test
        self.return_index = False

    def __getitem__(self, index):
        """Return the transformed sample at *index*.

        Args:
            index (int): sample position.
        Returns:
            See class docstring for the tuple shape.
        """
        full_path = os.path.join(self.root, self.imgs[index])
        target = self.labels[index]
        img = self.loader(full_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.test:
            return img, target, self.imgs[index]
        if self.return_index:
            return img, target, index
        return img, target

    def __len__(self):
        return len(self.imgs)
class STL(object):
    """STL dataset loaded from pre-dumped numpy arrays under *root*.

    ``ttype`` selects the split: "labeled", "unlabeled", or anything else
    for the test split. Items are (img, target); with ``return_index``
    set, (img, target, index); in test mode, (img, target, raw_array).
    """
    def __init__(self, root="./data/multi/", ttype="labeled",
                 transform=None, target_transform=None, test=False):
        if ttype == "labeled":
            imgs = np.load(os.path.join(root, "labeled_data"))
            labels = np.load(os.path.join(root, "labeled_label"))
        elif ttype == "unlabeled":
            imgs = np.load(os.path.join(root, "unlabeled_data"))
            labels = np.load(os.path.join(root, "unlabeled_label"))
        else:
            imgs = np.load(os.path.join(root, "test_data"))
            labels = np.load(os.path.join(root, "test_label"))
        self.imgs = imgs
        # BUG FIX: the loaded arrays were bound to a misspelled name
        # ("lables") while this assignment read an undefined "labels",
        # raising NameError on every construction.
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.loader = pil_loader
        self.root = root
        self.test = test
        self.return_index = False

    def __getitem__(self, index):
        """Return the transformed sample at *index* (see class docstring)."""
        target = self.labels[index]
        img = Image.fromarray(self.imgs[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if not self.test:
            if self.return_index:
                return img, target, index
            else:
                return img, target
        else:
            return img, target, self.imgs[index]

    def __len__(self):
        return len(self.imgs)
\ No newline at end of file
class BaseDataLoader():
    """Minimal base class for data loaders; subclasses override load_data()."""
    def __init__(self):
        pass

    def initialize(self, batch_size):
        """Store loader options.

        Args:
            batch_size (int): samples per batch.
        """
        self.batch_size = batch_size
        # 0 means batches are shuffled rather than served serially.
        self.serial_batches = 0
        self.nThreads = 2
        self.max_dataset_size = float("inf")

    def load_data(self):
        # BUG FIX: this was declared without `self`, so calling it on an
        # instance raised TypeError; subclasses override it anyway.
        return None
import torch
import numpy as np
import random
from PIL import Image
from torch.utils.data import Dataset,Sampler
from collections import defaultdict
import os
import os.path
import cv2
import torchvision
def make_dataset(image_list, labels):
    """Pair each image path with its label(s).

    If *labels* is given, zip it row-wise with the stripped lines;
    otherwise parse labels from each line — "path l" gives an int label,
    "path l1 l2 ..." gives an int array (multi-label).
    """
    if labels:
        count = len(image_list)
        return [(image_list[i].strip(), labels[i, :]) for i in range(count)]
    multi_label = len(image_list[0].split()) > 2
    if multi_label:
        return [(line.split()[0], np.array([int(tok) for tok in line.split()[1:]]))
                for line in image_list]
    return [(line.split()[0], int(line.split()[1])) for line in image_list]
def rgb_loader(path):
    """Open the file at *path* and return it as an RGB PIL.Image."""
    with open(path, 'rb') as handle:
        with Image.open(handle) as img:
            return img.convert('RGB')
def l_loader(path):
    """Open the file at *path* and return it as a grayscale ('L') PIL.Image."""
    with open(path, 'rb') as handle:
        with Image.open(handle) as img:
            return img.convert('L')
class ImageList(Dataset):
    """Dataset over "path label" lines with per-class index bookkeeping.

    ``classwise_indices`` maps each class label to the sample indices
    carrying it, which PairBatchSampler uses to draw same-class pairs.
    """
    def __init__(self, image_list, root=None, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        self.classwise_indices = defaultdict(list)
        self.idx2class = [target for _, target in self.imgs]
        for idx, classes in enumerate(self.idx2class):
            self.classwise_indices[classes].append(idx)

    def get_class(self, idx):
        """Return the class label of sample *idx*."""
        return self.idx2class[idx]

    def __getitem__(self, index):
        path, target = self.imgs[index]
        # BUG FIX: root defaults to None and os.path.join(None, path)
        # raises TypeError; only prepend root when it is given.
        if self.root is not None:
            path = os.path.join(self.root, path)
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)
class ImageList_idx(Dataset):
    """Like ImageList, but __getitem__ also returns the sample index and
    the class bookkeeping can be rebuilt from pseudo-labels."""
    def __init__(self, image_list, root=None, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            # BUG FIX: the old message referenced undefined IMG_EXTENSIONS,
            # turning this error path into a NameError.
            raise RuntimeError("Found 0 images in image list (root: {})".format(root))
        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        self.classwise_indices = defaultdict(list)
        self.idx2class = [target for _, target in self.imgs]
        for idx, classes in enumerate(self.idx2class):
            self.classwise_indices[classes].append(idx)

    def set_psd_class(self, pred):
        """Replace per-sample classes with pseudo-label predictions.

        *pred* must be aligned with the unshuffled sample order; the
        class-to-indices map is rebuilt from it.
        """
        self.idx2class = list(pred.reshape(-1))
        self.classwise_indices = defaultdict(list)
        for idx, classes in enumerate(self.idx2class):
            self.classwise_indices[classes].append(idx)

    def get_class(self, idx):
        """Return the (possibly pseudo) class label of sample *idx*."""
        return self.idx2class[idx]

    def __getitem__(self, index):
        path, target = self.imgs[index]
        # BUG FIX: root defaults to None and os.path.join(None, path)
        # raises TypeError; only prepend root when it is given.
        if self.root is not None:
            path = os.path.join(self.root, path)
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index

    def __len__(self):
        return len(self.imgs)
class PairBatchSampler(Sampler):
    """Batch sampler that appends a same-class partner for every index.

    Each yielded batch holds ``batch_size`` base indices followed by one
    randomly chosen index of the same class for each base index, so the
    batch length is ``2 * batch_size``.
    """
    def __init__(self, dataset, batch_size, num_iterations=None):
        self.dataset = dataset
        self.batch_size = batch_size
        # When set, each batch is an independent random sample instead of
        # a slice of one shuffled epoch.
        self.num_iterations = num_iterations

    def __iter__(self):
        order = list(range(len(self.dataset)))
        random.shuffle(order)
        for step in range(len(self)):
            if self.num_iterations is None:
                start = step * self.batch_size
                base = order[start:start + self.batch_size]
            else:
                base = random.sample(range(len(self.dataset)), self.batch_size)
            partners = [
                random.choice(self.dataset.classwise_indices[self.dataset.get_class(idx)])
                for idx in base
            ]
            yield base + partners

    def __len__(self):
        if self.num_iterations is None:
            return len(self.dataset) // self.batch_size
        return self.num_iterations
\ No newline at end of file
import sys
# sys.path.append('../data')
from data_pro.unaligned_data_loader import UnalignedDataLoader
from data_pro.svhn import load_svhn
from data_pro.mnist import load_mnist
from data_pro.mnist_m import load_mnistm
from data_pro.usps_ import load_usps
from data_pro.gtsrb import load_gtsrb
from data_pro.synth_number import load_syn
from data_pro.synth_traffic import load_syntraffic
def return_dataset(data, scale=False, usps=False, all_use='no'):
    """Load one digits domain and return its train/test arrays.

    Args:
        data: one of 'svhn', 'mnist', 'mnistm', 'usps', 'synth', 'gtsrb',
            'syn'.
        scale, usps, all_use: unused here; kept for interface compatibility.
    Returns:
        (train_image, train_label, test_image, test_label).
    Raises:
        ValueError: for an unknown domain name — previously this fell
            through to an UnboundLocalError at the return statement.
    """
    if data == 'svhn':
        train_image, train_label, test_image, test_label = load_svhn()
    elif data == 'mnist':
        train_image, train_label, test_image, test_label = load_mnist()
    elif data == 'mnistm':
        train_image, train_label, test_image, test_label = load_mnistm()
    elif data == 'usps':
        train_image, train_label, test_image, test_label = load_usps()
    elif data == 'synth':
        train_image, train_label, test_image, test_label = load_syntraffic()
    elif data == 'gtsrb':
        train_image, train_label, test_image, test_label = load_gtsrb()
    elif data == 'syn':
        train_image, train_label, test_image, test_label = load_syn()
    else:
        raise ValueError("unknown dataset: %s" % data)
    return train_image, train_label, test_image, test_label
def dataset_read(target, batch_size):
    """Assemble multi-source train/test loaders with *target* held out.

    Every domain in {mnistm, mnist, usps, svhn, syn} except *target*
    becomes a source; all test splits evaluate on target data.

    Returns:
        (train_dataset, test_dataset) as produced by UnalignedDataLoader.
    """
    num_sources = 4
    S = [{} for _ in range(num_sources)]
    S_test = [{} for _ in range(num_sources)]
    T = {}
    T_test = {}
    domain_all = ['mnistm', 'mnist', 'usps', 'svhn', 'syn']
    domain_all.remove(target)
    target_train, target_train_label, target_test, target_test_label = return_dataset(target)
    print(domain_all)
    for i, domain in enumerate(domain_all):
        source_train, source_train_label, _, _ = return_dataset(domain)
        S[i]['imgs'] = source_train
        S[i]['labels'] = source_train_label
        # Source performance is not tracked: test splits use target samples.
        S_test[i]['imgs'] = target_test
        S_test[i]['labels'] = target_test_label
    T['imgs'] = target_train
    T['labels'] = target_train_label
    T_test['imgs'] = target_test
    T_test['labels'] = target_test_label
    scale = 32
    train_loader = UnalignedDataLoader()
    train_loader.initialize(S, T, batch_size, batch_size, scale=scale)
    dataset = train_loader.load_data()
    test_loader = UnalignedDataLoader()
    test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=scale)
    dataset_test = test_loader.load_data()
    return dataset, dataset_test
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import numpy as np
class Dataset(data.Dataset):
    """In-memory dataset over pre-loaded image/label arrays.

    Args:
        data: array of images in CHW order.
        label: array of targets aligned with *data*.
        transform (callable, optional): applied to the PIL image.
        target_transform (callable, optional): applied to the target.
    """
    def __init__(self, data, label,
                 transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.data = data
        self.labels = label

    def __getitem__(self, index):
        """Return (image, target) for sample *index* as a PIL-based pair."""
        img, target = self.data[index], self.labels[index]
        # Convert the CHW array to a PIL Image so transforms behave like
        # on every other dataset; single-channel inputs are replicated to
        # three channels.
        if img.shape[0] != 1:
            img = Image.fromarray(np.uint8(np.asarray(img.transpose((1, 2, 0)))))
        elif img.shape[0] == 1:
            gray = np.uint8(np.asarray(img))
            stacked = np.vstack([gray, gray, gray]).transpose((1, 2, 0))
            img = Image.fromarray(stacked)
        # target_transform is applied before the image transform.
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)
import pickle as p
import numpy as np
import os
from PIL import Image
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR batch file.

    Args:
        filename: path to a pickled batch with 'data' (N x 3072 rows) and
            'labels' entries.
    Returns:
        X: float ndarray of shape (N, 32, 32, 3) in HWC order.
        Y: int ndarray of labels with shape (N,).
    """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
    X = datadict['data']
    Y = datadict['labels']
    # reshape(-1, ...) generalizes the hard-coded 10000-row batch size so
    # batches of any length load; transpose CHW -> HWC.
    X = np.asarray(X).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
def load_CIFAR10(ROOT):
    """Load all five CIFAR-10 training batches plus the test batch.

    Args:
        ROOT: directory holding 'data_batch_1'..'data_batch_5' and
            'test_batch'.
    Returns:
        (Xtr, Ytr, Xte, Yte) with training images concatenated across the
        five batches.
    """
    train_images = []
    train_labels = []
    for batch_id in range(1, 6):
        batch_path = os.path.join(ROOT, 'data_batch_%d' % (batch_id,))
        X, Y = load_CIFAR_batch(batch_path)
        train_images.append(X)
        train_labels.append(Y)
    # Stack the per-batch arrays into single train tensors.
    Xtr = np.concatenate(train_images)
    Ytr = np.concatenate(train_labels)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def split_ssl_data(data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True):
    """Split (data, target) into labeled and unlabeled portions.

    Args:
        index: optional explicit indices for the labeled subset.
        include_lb_to_ulb: when True the unlabeled portion is the full
            set, labeled samples included.
    Returns:
        (lb_data, lb_targets, ulb_data, ulb_targets).
    """
    data = np.array(data)
    target = np.array(target)
    lb_data, lbs, lb_idx = sample_labeled_data(data, target, num_labels, num_classes, index)
    if include_lb_to_ulb:
        return lb_data, lbs, data, target
    # Complement of the labeled indices, kept in ascending order.
    remaining = np.array(sorted(set(range(len(data))) - set(lb_idx)))
    return lb_data, lbs, data[remaining], target[remaining]
def sample_labeled_data(data, target,
                        num_labels, num_classes,
                        index=None, name=None):
    """Sample a class-balanced labeled subset.

    Args:
        data, target: full arrays of samples and integer labels.
        num_labels: total labeled sample count (multiple of num_classes).
        num_classes: number of classes; num_labels // num_classes are
            drawn per class.
        index: optional explicit indices; when given they are returned
            directly without sampling.
        name: unused; kept for interface compatibility.
    Returns:
        (lb_data, lbs, lb_idx) arrays of samples, labels and indices.
    """
    assert num_labels % num_classes == 0
    if index is not None:
        index = np.array(index, dtype=np.int32)
        return data[index], target[index], index
    samples_per_class = num_labels // num_classes
    # Use a private RandomState rather than np.random.seed(2022), which
    # clobbered the global numpy RNG for the whole process; the legacy
    # RandomState with the same seed draws the identical sequence.
    rng = np.random.RandomState(2022)
    lb_data = []
    lbs = []
    lb_idx = []
    for c in range(num_classes):
        candidates = np.where(target == c)[0]
        chosen = rng.choice(candidates, samples_per_class, False)
        lb_idx.extend(chosen)
        lb_data.extend(data[chosen])
        lbs.extend(target[chosen])
    return np.array(lb_data), np.array(lbs), np.array(lb_idx)
def mysave(dataset, lb_data, lbs, txt_path, ROOT, cnt):
    """Save (image, label) pairs as JPEGs and write an index list file.

    Images go to ROOT/<label>/<cnt>.jpg; each line of *txt_path* is
    "dataset/label/cnt.jpg label". The counter continues from *cnt* and
    the updated value is returned so successive calls never collide.
    """
    with open(txt_path, "w") as f:
        first = True
        for img, label in zip(lb_data, lbs):
            class_dir = os.path.join(ROOT, str(label))
            if not os.path.exists(class_dir):
                os.makedirs(class_dir)
            image = Image.fromarray(img.astype(np.uint8))
            image = image.convert('RGB')
            image.save(os.path.join(class_dir, "{}.jpg".format(cnt)))
            record = "{}/{}/{}.jpg {}".format(dataset, label, cnt, label)
            # Leading newline on every record except the first keeps the
            # file free of a trailing blank line.
            f.writelines([record if first else "\n" + record])
            first = False
            cnt += 1
    return cnt
# --- CIFAR-10 SSL split script -------------------------------------------
# Loads CIFAR-10, draws a class-balanced labeled subset and dumps
# labeled/unlabeled/validation splits as JPEG folders plus list files.
ROOT="/data/maning/datasets/cifar-10-batches-py/"
Xtr, Ytr, Xte, Yte=load_CIFAR10(ROOT)
print(Xtr.shape,Ytr.shape,Xte.shape,Yte.shape)
num_classes=10
dataset="CIFAR10_4(1)"
# 4 labeled samples per class.
num_labels=num_classes*4
lb_data, lbs, data, target=split_ssl_data( Xtr, Ytr, num_labels, num_classes, index=None, include_lb_to_ulb=True)
# NOTE(review): rebinding `p` here shadows the `pickle as p` module alias
# imported above — confirm nothing below still needs the pickle alias.
p=os.path.join(ROOT,"SSL",dataset)
import shutil
# Start from a clean output directory.
if os.path.exists(p):
    shutil.rmtree(p)
cnt=0
# The image counter chains across the three dumps so filenames never collide.
cnt1=mysave(dataset,lb_data,lbs,os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","labeled_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt)
cnt2=mysave(dataset,data,target,os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","unlabeled_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt1)
cnt3=mysave(dataset,Xte,Yte, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","validation_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt2)
print(cnt1,cnt2,cnt3)
# Earlier approach kept for reference: dumped the raw arrays instead of JPEGs.
# np.save("/data/maning/datasets/cifar-10-batches-py/labeled_data.npy", np.array(lb_data))
# np.save("/data/maning/datasets/cifar-10-batches-py/labeled_label.npy", np.array(lbs))
# np.save("/data/maning/datasets/cifar-10-batches-py/unlabeled_data.npy", np.array(data))
# np.save("/data/maning/datasets/cifar-10-batches-py/unlabeled_label.npy", np.array(target))
# np.save("/data/maning/datasets/cifar-10-batches-py/test_data.npy", np.array(Xte))
# np.save("/data/maning/datasets/cifar-10-batches-py/test_label.npy", np.array(Yte))
import pickle as p
import numpy as np
import os
from PIL import Image
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR batch file.

    Args:
        filename: path to a pickled batch with 'data' (N x 3072 rows) and
            'labels' entries.
    Returns:
        X: float ndarray of shape (N, 32, 32, 3) in HWC order.
        Y: int ndarray of labels with shape (N,).
    """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
    X = datadict['data']
    Y = datadict['labels']
    # reshape(-1, ...) generalizes the hard-coded 10000-row batch size so
    # batches of any length load; transpose CHW -> HWC.
    X = np.asarray(X).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
def unpickle(file):
    """Load and return the pickled object stored at *file* (latin1 decoding)."""
    # Context manager guarantees the handle is closed even if load() raises
    # (the old open/close pair leaked the handle on error); also avoids
    # shadowing the builtin `dict`.
    with open(file, 'rb') as handle:
        return p.load(handle, encoding='latin1')
def load_CIFAR100(ROOT):
    """Load the CIFAR-100 train/test splits (fine labels) from ROOT.

    Returns:
        (Xtr, Ytr, Xte, Yte) as numpy arrays; images stay in the flat
        row format stored in the pickle files.
    """
    train_batch = unpickle(os.path.join(ROOT, "train"))
    Xtr = np.array(train_batch['data'])
    Ytr = np.array(train_batch['fine_labels'])
    test_batch = unpickle(os.path.join(ROOT, "test"))
    Xte = np.array(test_batch['data'])
    Yte = np.array(test_batch['fine_labels'])
    return Xtr, Ytr, Xte, Yte
def split_ssl_data(data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True):
    """Split (data, target) into labeled and unlabeled portions.

    Args:
        index: optional explicit indices for the labeled subset.
        include_lb_to_ulb: when True the unlabeled portion is the full
            set, labeled samples included.
    Returns:
        (lb_data, lb_targets, ulb_data, ulb_targets).
    """
    data = np.array(data)
    target = np.array(target)
    lb_data, lbs, lb_idx = sample_labeled_data(data, target, num_labels, num_classes, index)
    if include_lb_to_ulb:
        return lb_data, lbs, data, target
    # Complement of the labeled indices, kept in ascending order.
    remaining = np.array(sorted(set(range(len(data))) - set(lb_idx)))
    return lb_data, lbs, data[remaining], target[remaining]
def sample_labeled_data(data, target,
                        num_labels, num_classes,
                        index=None, name=None):
    """Sample a class-balanced labeled subset.

    Args:
        data, target: full arrays of samples and integer labels.
        num_labels: total labeled sample count (multiple of num_classes).
        num_classes: number of classes; num_labels // num_classes are
            drawn per class.
        index: optional explicit indices; when given they are returned
            directly without sampling.
        name: unused; kept for interface compatibility.
    Returns:
        (lb_data, lbs, lb_idx) arrays of samples, labels and indices.
    """
    assert num_labels % num_classes == 0
    if index is not None:
        index = np.array(index, dtype=np.int32)
        return data[index], target[index], index
    samples_per_class = num_labels // num_classes
    # Use a private RandomState rather than np.random.seed(2022), which
    # clobbered the global numpy RNG for the whole process; the legacy
    # RandomState with the same seed draws the identical sequence.
    rng = np.random.RandomState(2022)
    lb_data = []
    lbs = []
    lb_idx = []
    for c in range(num_classes):
        candidates = np.where(target == c)[0]
        chosen = rng.choice(candidates, samples_per_class, False)
        lb_idx.extend(chosen)
        lb_data.extend(data[chosen])
        lbs.extend(target[chosen])
    return np.array(lb_data), np.array(lbs), np.array(lb_idx)
def mysave(dataset, lb_data, lbs, txt_path, ROOT, cnt):
    """Save flat CIFAR rows as JPEGs and write an index list file.

    Images go to ROOT/<label>/<cnt>.jpg; each line of *txt_path* is
    "dataset/label/cnt.jpg label". The counter continues from *cnt* and
    the updated value is returned so successive calls never collide.
    """
    with open(txt_path, "w") as f:
        first = True
        for img, label in zip(lb_data, lbs):
            class_dir = os.path.join(ROOT, str(label))
            if not os.path.exists(class_dir):
                os.makedirs(class_dir)
            # Rows are flat 3072-vectors: reshape to CHW then transpose to HWC.
            image = Image.fromarray(img.reshape(3, 32, 32).transpose(1, 2, 0))
            image = image.convert('RGB')
            image.save(os.path.join(class_dir, "{}.jpg".format(cnt)))
            record = "{}/{}/{}.jpg {}".format(dataset, label, cnt, label)
            # Leading newline on every record except the first keeps the
            # file free of a trailing blank line.
            f.writelines([record if first else "\n" + record])
            first = False
            cnt += 1
    return cnt
# --- CIFAR-100 SSL split script ------------------------------------------
# Same pipeline as the CIFAR-10 script above, for CIFAR-100 fine labels.
ROOT="/data/maning/datasets/cifar-100-python/"
target_path="/data/maning/git/shot/data"
Xtr, Ytr, Xte, Yte=load_CIFAR100(ROOT)
print(Xtr.shape,Ytr.shape,Xte.shape,Yte.shape)
num_classes=100
dataset="CIFAR100_4"
# 4 labeled samples per class.
num_labels=num_classes*4
lb_data, lbs, data, target=split_ssl_data( Xtr, Ytr, num_labels, num_classes, index=None, include_lb_to_ulb=True)
# NOTE(review): rebinding `p` shadows the `pickle as p` module alias
# imported above — confirm nothing below still needs the pickle alias.
p=os.path.join(target_path,"SSL",dataset)
import shutil
# Start from a clean output directory.
if os.path.exists(p):
    shutil.rmtree(p)
cnt=0
# The image counter chains across the three dumps so filenames never collide.
cnt1=mysave(dataset,lb_data,lbs,os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","labeled_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt)
cnt2=mysave(dataset,data,target,os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","unlabeled_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt1)
cnt3=mysave(dataset,Xte,Yte, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","validation_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))),p,cnt2)
# os.path.join(
# Reuse the validation list as the labeled-source list for the SSDA layout.
shutil.copy(os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","validation_target_images_{}_{}.txt".format(dataset,int(num_labels/num_classes)))
,os.path.join("/data/maning/git/shot/data/SSDA_split/SSL","labeled_source_images_{}_{}.txt".format(dataset,int(num_labels/num_classes))
))
print(cnt1,cnt2,cnt3)
# Earlier approach kept for reference: dumped the raw arrays instead of JPEGs.
# np.save("/data/maning/datasets/cifar-10-batches-py/labeled_data.npy", np.array(lb_data))
# np.save("/data/maning/datasets/cifar-10-batches-py/labeled_label.npy", np.array(lbs))
# np.save("/data/maning/datasets/cifar-10-batches-py/unlabeled_data.npy", np.array(data))
# np.save("/data/maning/datasets/cifar-10-batches-py/unlabeled_label.npy", np.array(target))
# np.save("/data/maning/datasets/cifar-10-batches-py/test_data.npy", np.array(Xte))
# np.save("/data/maning/datasets/cifar-10-batches-py/test_label.npy", np.array(Yte))
import numpy as np
import pickle as pkl
def load_gtsrb():
    """Load the pickled GTSRB data and return a shuffled train/test split.

    Returns:
        (train_images, train_labels, test_images, test_labels) with images
        transposed to NCHW float32 and labels shifted to start at 1.
    """
    # BUG FIX: pickle files must be opened in binary mode under Python 3
    # (text mode raises on decode); the handle is also now closed.
    with open('../data/data_gtsrb', 'rb') as f:
        data_target = pkl.load(f)
    perm = np.random.permutation(len(data_target['image']))
    # First 31367 shuffled samples train, the remainder test.
    data_t_im = data_target['image'][perm[:31367], :, :, :]
    data_t_im_test = data_target['image'][perm[31367:], :, :, :]
    data_t_label = data_target['label'][perm[:31367]] + 1
    data_t_label_test = data_target['label'][perm[31367:]] + 1
    data_t_im = data_t_im.transpose(0, 3, 1, 2).astype(np.float32)
    data_t_im_test = data_t_im_test.transpose(0, 3, 1, 2).astype(np.float32)
    return data_t_im, data_t_label, data_t_im_test, data_t_label_test
import os
import torch
import numpy as np
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.data_list import ImageList_idx,ImageList
import collections
import torch
class ResizeImage():
    """Resize a PIL.Image to a fixed target size (int means square)."""
    def __init__(self, size):
        if isinstance(size, int):
            size = (int(size), int(size))
        self.size = size

    def __call__(self, img):
        first, second = self.size
        return img.resize((first, second))
class TransformTwice:
    """Apply the wrapped transform to the same input twice.

    Used for consistency regularization: each call returns two
    independently computed outputs of the same transform.
    """
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        return self.transform(inp), self.transform(inp)
def return_dataset(args):
    """Build all SSDA datasets and loaders for one source->target task.

    Reads the split text files (labeled source, labeled/validation/unlabeled
    target), splits the source list into train/val, and returns:
    (source_loader, source_val_loader, target_loader, target_loader_unl,
     target_loader_val, target_loader_test, (target_dataset, target_dataset_unl)).
    """
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring membership — any args.dataset that is a
    # substring of "office-home" (e.g. "office") also matches; confirm
    # equality was intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else :
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.s + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; the other backbones use 224x224.
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # NOTE(review): 'val' uses random flip/crop (same as 'train'), not a
        # deterministic pipeline — confirm this is deliberate.
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    def split(train_r, source_path):
        # Randomly partition the lines of source_path into train/val
        # with ratio train_r : (1 - train_r).
        with open(source_path, 'r') as f:
            data = f.readlines()
        train_len = int(len(data) * train_r)
        train, val = torch.utils.data.random_split(data, [train_len, len(data) - train_len])
        return train, val
    # NOTE(review): substring membership again ("multi" vs args.dataset).
    if args.dataset in "multi":
        source_train, source_val = split(train_r=0.95, source_path=image_set_file_s)
    else:
        source_train, source_val = split(train_r=0.90, source_path=image_set_file_s)
    print("source_train and val num", len(source_train), len(source_val))
    source_dataset = Imagelists_VISDA(source_train, root=root,
                                      transform=data_transforms['train'])
    source_val_dataset = Imagelists_VISDA(source_val, root=root,
                                          transform=data_transforms['val'])
    target_dataset = Imagelists_VISDA(open(image_set_file_t).readlines(), root=root,
                                      transform=data_transforms['val'])
    target_dataset_val = Imagelists_VISDA(open(image_set_file_t_val).readlines(), root=root,
                                          transform=data_transforms['val'])
    # Unlabeled target: every sample yields two random augmented views.
    target_dataset_unl = Imagelists_VISDA(open(image_set_file_unl).readlines(), root=root,
                                          transform=TransformTwice(data_transforms['val']))
    # target_dataset_unl.return_index = True
    target_dataset_test = Imagelists_VISDA(open(image_set_file_unl).readlines(), root=root,
                                           transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    # if args.net == 'alexnet':
    # bs = 20
    # else:
    # bs = 16
    bs=args.batch_size
    source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs,
                                                # pin_memory=True,
                                                num_workers=args.worker, shuffle=True,
                                                drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=bs,
                                                    # pin_memory=True,
                                                    num_workers=args.worker, shuffle=False,
                                                    drop_last=False)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=False)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    # pin_memory=True,
                                    shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 2, num_workers=args.worker,
                                    # pin_memory=True,
                                    shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
           target_loader_val, target_loader_test, (target_dataset,target_dataset_unl)
def return_dataloader_by_entropy(args,netF,netC,unlabeled_data_loader):
    """Pseudo-label low-entropy unlabeled target samples and rebuild loaders.

    Runs netF/netC over the unlabeled split, selects samples whose prediction
    entropy is below half the mean entropy of their predicted class, appends
    them with pseudo labels to the labeled target list, removes them from the
    unlabeled list, and returns fresh (target_loader, target_loader_unl).

    NOTE(review): ``unlabeled_data_loader`` is never used — the function
    rebuilds its own loader from the split file; confirm the parameter is
    intentional.
    """
    netF.eval()
    netC.eval()
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring test (see return_dataset) — confirm intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=data_transforms['val'])
    # Make the dataset also return each sample's position in the list so the
    # selected samples can be mapped back to lines of unlabel_target_list.
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    # Accumulate logits, labels and dataset indices for the whole split.
    with torch.no_grad():
        all_index=[]
        for step, (inputs, labels, index) in enumerate(target_loader_unl):
            # = data[0]
            # labels = data[1]
            inputs = inputs.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            outputs = netC(netF(inputs))
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index=index.int()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index.int()), 0)
                # print(type(index))
        # all_index.append(index)
    print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    class_to_entropy = list(0. for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    # kk = list(0. for i in range(args.class_num))
    # for label_idx in range(len(all_label)):
    # label_single = int(all_label[label_idx].item())
    # correct[label_single] += res[label_idx].item()
    # total[label_single] += 1
    # for label in range(126):
    # kk[label] = 1.0 * correct[label] / total[label]
    # Mean prediction entropy per predicted class.
    for label in range(args.class_num):
        class_to_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        classforentropy[label]=list(entropy[pred_label == label].numpy())
    # Keep samples whose entropy is under half their class-mean entropy.
    for index in all_index:
        label=pred_label[index].item()
        if entropy[index] < class_to_entropy[label]*0.5:
            class2index[label].append(index)
    # print(class2index.items())
    # assert
    # for index , label in enumerate(all_label):
    # print(label,unlabel_target_list[index])
    max_num_per_class=5
    line2remove=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            # total_add is keyed by the *ground-truth* label from the split file.
            total_add[int(line.split(" ")[1])] += 1
            # NOTE(review): this break only exits the inner loop once the
            # ground-truth class cap is hit — confirm that is the intended cap.
            if total_add[int(line.split(" ")[1])] >= max_num_per_class:
                break
    print("added number for each class ",sum(total_add),total_add)
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    # NOTE(review): sibling functions restore .train() here; this one leaves
    # the networks in eval mode — confirm that is intentional.
    netF.eval()
    netC.eval()
    return target_loader, target_loader_unl
def return_dataloader_by_topK_entropy(args,netF,netC,unlabeled_data_loader):
    """Pseudo-label the K lowest-entropy unlabeled samples per predicted class.

    Runs netF/netC over the unlabeled split (deterministic 'test' transform),
    picks up to args.max_num_per_class samples with the smallest prediction
    entropy for each predicted class, moves them (with pseudo labels) into the
    labeled target list, prints pseudo-label accuracy, and returns fresh
    (target_loader, target_loader_unl).

    NOTE(review): ``unlabeled_data_loader`` is unused (a loader is rebuilt
    from the split file) — confirm the parameter is intentional.
    """
    netF.eval()
    netC.eval()
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring test (see return_dataset) — confirm intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=data_transforms['test'])
    # Return each sample's position so selections map back to list lines.
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    # Accumulate logits, labels and dataset indices for the whole split.
    with torch.no_grad():
        # all_index=[]
        for step, (inputs, labels,index) in enumerate(target_loader_unl):
            # = data[0]
            # labels = data[1]
            inputs = inputs.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            outputs = netC(netF(inputs))
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index=index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
                # print(type(index))
        # all_index.append(index)
    # print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    class_to_mean_entropy = list(0. for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    for label in range(args.class_num):
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # classforentropy[label]=list(entropy[pred_label == label].numpy())
        # Map class-local positions back to global dataset indices.
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        # print("index2all_label",index2all_index[label])
        classentropy=entropy[pred_label == label]
        # Take all samples when fewer than K exist, else the K smallest-entropy.
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        # print("entroy_len",len(entropy[pred_label == label]),"index_len",len(all_index[pred_label == label]))
        # print("indexes",indexes)
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    # for index in all_index:
    #     label=pred_label[index].item()
    #     if entropy[index] < class_to_mean_entropy[label]*0.5:
    #         class2index[label].append(index)
    #
    # # print(class2index.items())
    #
    #
    # # assert
    # # for index , label in enumerate(all_label):
    # # print(label,unlabel_target_list[index])
    #
    total=0
    acc=0
    line2remove=[]
    # Move selected lines to the labeled list with their pseudo label and
    # track pseudo-label accuracy against the ground-truth label in the file.
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            total_add[int(line.split(" ")[1])] += 1
            # if total_add[int(line.split(" ")[1])] >= max_num_per_class:
            #     break
            # print(line.split(" ")[1],str(psudo_label))
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    # NOTE(review): raises ZeroDivisionError when nothing was selected.
    print("acc of psudo label",acc*1.0/total)
    # print("added number for each class ",sum(total_add),total_add)
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    # print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    # Restore training mode before handing control back to the training loop.
    netF.train()
    netC.train()
    return target_loader, target_loader_unl
def apply_train_dropout(m):
    """Switch a Dropout module (back) into training mode.

    Intended for ``net.apply(apply_train_dropout)`` so MC-dropout sampling
    keeps dropout active while the rest of the network stays in eval mode.

    Args:
        m: a submodule visited by ``Module.apply``; non-Dropout modules
            are left untouched.
    """
    # isinstance is the idiomatic type check (and also accepts subclasses),
    # unlike the exact `type(m) == ...` comparison used before.
    if isinstance(m, torch.nn.Dropout):
        # print("find droput")
        m.train()
def return_dataloader_by_UPS(args,netF,netC,unlabeled_data_loader):
    """Uncertainty-aware pseudo-labeling (UPS-style) with MC dropout.

    Averages 2*5 stochastic forward passes (two augmented views, dropout kept
    active via apply_train_dropout) per unlabeled sample, selects up to
    args.max_num_per_class lowest-entropy samples per predicted class, moves
    them into the labeled target list, and returns
    (target_loader, target_loader_unl, pseudo_label_accuracy).

    When args.uda == 1 the labeled *source* list is treated as the unlabeled
    pool and target_list is reset to pseudo-labeled lines only.

    NOTE(review): ``unlabeled_data_loader`` is unused — confirm intentional.
    """
    netF.eval()
    netC.eval()
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring test (see return_dataset) — confirm intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    # UDA mode pseudo-labels the labeled source list instead of the
    # unlabeled target list.
    if args.uda==1:
        image_set_file_unl = \
            os.path.join(base_path,
                         'labeled_source_images_' +
                         args.t+".txt")
    else:
        image_set_file_unl = \
            os.path.join(base_path,
                         'unlabeled_target_images_' +
                         args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    netF.apply(apply_train_dropout)# MC dropout
    netC.apply(apply_train_dropout)
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    # Return each sample's position so selections map back to list lines.
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            # print(step,index)
            # = data[0]
            # labels = data[1]
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            output=[]
            batch_s=inputs2.shape[0]
            repeat=5
            # 2*repeat stochastic passes (dropout active) over the two views;
            # the averaged logits approximate the MC-dropout expectation.
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            # print("std",torch.std(output,dim=0))
            outputs=torch.mean(output,dim=0)
            # print("outputs",outputs.shape)
            # print("mean",outputs)
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            # step=torch.tensor(step).int().reshape(-1)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
                # print(type(index))
        # all_index.append(index)
    # print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    class_to_mean_entropy = list(0. for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    # print(all_index.shape,pred_label.shape)
    for label in range(args.class_num):
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # classforentropy[label]=list(entropy[pred_label == label].numpy())
        # Map class-local positions back to global dataset indices.
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        # print("index2all_label",index2all_index[label])
        classentropy=entropy[pred_label == label]
        # Take all samples when fewer than K exist, else the K smallest-entropy.
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        # print("entroy_len",len(entropy[pred_label == label]),"index_len",len(all_index[pred_label == label]))
        # print("indexes",indexes)
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    # for index in all_index:
    #     label=pred_label[index].item()
    #     if entropy[index] < class_to_mean_entropy[label]*0.5:
    #         class2index[label].append(index)
    #
    # # print(class2index.items())
    #
    #
    # # assert
    # # for index , label in enumerate(all_label):
    # # print(label,unlabel_target_list[index])
    #
    total=0
    acc=0
    line2remove=[]
    # In pure UDA the labeled target list starts empty: only pseudo labels.
    if args.uda == 1:
        target_list=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            total_add[int(line.split(" ")[1])] += 1
            # if total_add[int(line.split(" ")[1])] >= max_num_per_class:
            #     break
            # print(line.split(" ")[1],str(psudo_label))
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    # print("acc of psudo label",acc*1.0/total)
    # print("added number for each class ",sum(total_add),total_add)
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    # print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    netF.train()
    netC.train()
    # NOTE(review): divides by zero when nothing was selected (total == 0).
    return target_loader, target_loader_unl,acc*1.0/total
def return_dataloader_by_progressive_UPS(args,netF,netC,unlabeled_data_loader):
    """Progressive UPS pseudo-labeling: per-class budget grows over epochs.

    Same MC-dropout averaging as return_dataloader_by_UPS, but after epoch 0
    the per-class selection budget is scaled by exp(2 / class_mean_entropy)
    (capped at half the uniform share), so confident classes absorb more
    pseudo labels as training progresses. Returns
    (target_loader, target_loader_unl, pseudo_label_accuracy).

    NOTE(review): ``unlabeled_data_loader`` is unused — confirm intentional.
    """
    netF.eval()
    netC.eval()
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring test (see return_dataset) — confirm intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    # UDA mode pseudo-labels the labeled source list instead of the
    # unlabeled target list.
    if args.uda==1:
        image_set_file_unl = \
            os.path.join(base_path,
                         'labeled_source_images_' +
                         args.t+".txt")
    else:
        image_set_file_unl = \
            os.path.join(base_path,
                         'unlabeled_target_images_' +
                         args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    netF.apply(apply_train_dropout)# MC dropout
    netC.apply(apply_train_dropout)
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    # Return each sample's position so selections map back to list lines.
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            # print(step,index)
            # = data[0]
            # labels = data[1]
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            output=[]
            batch_s=inputs2.shape[0]
            repeat=5
            # 2*repeat stochastic passes (dropout active) over the two views;
            # the averaged logits approximate the MC-dropout expectation.
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            # print("std",torch.std(output,dim=0))
            outputs=torch.mean(output,dim=0)
            # print("outputs",outputs.shape)
            # print("mean",outputs)
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            # step=torch.tensor(step).int().reshape(-1)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
                # print(type(index))
        # all_index.append(index)
    # print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    class_to_mean_entropy = list(4 for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    # print(all_index.shape,pred_label.shape)
    for label in range(args.class_num):
        #none prediction belongs to lable class
        # If no sample was predicted as `label`, force the top-scoring
        # samples for that class to carry it so the class is not starved.
        if len(entropy[pred_label == label])==0:
            # NOTE(review): max_num_per_class is only assigned further down in
            # this loop body — if this branch runs on the first iteration it
            # raises NameError; confirm and hoist the default if needed.
            _, ind=torch.topk(all_output[:,label], max_num_per_class, largest=True)
            pred_label[ind]=label
            # class_to_mean_entropy[label] = 6
            # print("ssssssssss***********************************")
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # print("class to entropy",class_to_mean_entropy[label])
        # classforentropy[label]=list(entropy[pred_label == label].numpy())
        # Progressive budget: grow with confidence (low mean entropy) after
        # the first epoch, capped at half the uniform per-class share.
        if args.cur_epoch>0:
            max_num_per_class=int(args.max_num_per_class*np.exp(2/class_to_mean_entropy[label]))
            max_num_per_class=min(int(len(all_output)/args.class_num*0.5), max_num_per_class)
        else:
            max_num_per_class=args.max_num_per_class
        # print("numper classes ", max_num_per_class)
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        classentropy=entropy[pred_label == label]
        # Take all samples when fewer than the budget, else the lowest-entropy.
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        # print("entroy_len",len(entropy[pred_label == label]),"index_len",len(all_index[pred_label == label]))
        # print("indexes",indexes)
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    # for index in all_index:
    #     label=pred_label[index].item()
    #     if entropy[index] < class_to_mean_entropy[label]*0.5:
    #         class2index[label].append(index)
    #
    # # print(class2index.items())
    #
    #
    # # assert
    # # for index , label in enumerate(all_label):
    # # print(label,unlabel_target_list[index])
    #
    total=0
    acc=0
    line2remove=[]
    # In pure UDA the labeled target list starts empty: only pseudo labels.
    if args.uda == 1:
        target_list=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            total_add[int(line.split(" ")[1])] += 1
            # if total_add[int(line.split(" ")[1])] >= max_num_per_class:
            #     break
            # print(line.split(" ")[1],str(psudo_label))
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    # print("acc of psudo label",acc*1.0/total)
    # print("added number for each class ",sum(total_add),total_add)
    before_remove=len(unlabel_target_list)
    # Guarded removal: forced re-labeling above may select the same line for
    # two classes, so it can already be gone.
    for line in line2remove :
        if line in unlabel_target_list:
            unlabel_target_list.remove(line)
    # print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    netF.train()
    netC.train()
    # NOTE(review): divides by zero when nothing was selected (total == 0).
    return target_loader, target_loader_unl,acc*1.0/total
def return_dataloader_by_threshhold(args,netF,netC,unlabeled_data_loader):
    """Pseudo-label by softmax-confidence threshold with MC-dropout averaging.

    Averages 2*5 stochastic forward passes per unlabeled sample, keeps up to
    args.max_num_per_class samples per predicted class whose max softmax
    probability exceeds 0.7, moves them into the labeled target list, and
    returns fresh (target_loader, target_loader_unl).

    NOTE(review): ``unlabeled_data_loader`` is unused — confirm intentional.
    """
    netF.eval()
    netC.eval()
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): substring test (see return_dataset) — confirm intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    # Return each sample's position so selections map back to list lines.
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            # print(step,index)
            # = data[0]
            # labels = data[1]
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            output=[]
            batch_s=inputs2.shape[0]
            repeat=5
            # 2*repeat stochastic passes over the two augmented views;
            # averaged logits approximate the MC-dropout expectation.
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            # print("std",torch.std(output,dim=0))
            outputs=torch.mean(output,dim=0)
            # print("outputs",outputs.shape)
            # print("mean",outputs)
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            # step=torch.tensor(step).int().reshape(-1)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
                # print(type(index))
        # all_index.append(index)
    # print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    softmax = torch.nn.Softmax(dim=1)(all_output)
    prob,_ = torch.max(softmax,dim=1,keepdim=True)
    # Minimum confidence for a pseudo label.
    threshold=0.7
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    # class_to_mean_probobility = list(0. for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    # print(all_index.shape,pred_label.shape)
    for label in range(args.class_num):
        # class_to_mean_probobility[label] = prob[pred_label == label].item()
        # classforentropy[label]=list(entropy[pred_label == label].numpy())
        # print("index2all_label",index2all_index[label])
        # print(type(pred_label),type(label),type(prob),type(threshold))
        # Samples predicted as `label` AND above the confidence threshold.
        condition= ((pred_label.reshape(-1,1) == label) & (prob.reshape(-1,1) > threshold)).view(-1)
        # print(condition.shape)
        class_prob=prob[condition].view(-1)
        # class_prob=class_prob
        # Map class-local positions back to global dataset indices.
        index2all_index[label] = dict(
            list(zip(range(len(all_index[condition])), all_index[condition].numpy())))
        # Take all qualifying samples up to the cap, else the most confident.
        if len(class_prob)<=max_num_per_class:
            indexes=range(len(class_prob))
        else:
            # print(len(class_prob),class_prob.shape, max_num_per_class)
            _, indexes = torch.topk(class_prob, max_num_per_class, largest=True)
            indexes=indexes.numpy()
        # print("entroy_len",len(entropy[pred_label == label]),"index_len",len(all_index[pred_label == label]))
        # print("indexes",indexes)
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    # for index in all_index:
    #     label=pred_label[index].item()
    #     if entropy[index] < class_to_mean_entropy[label]*0.5:
    #         class2index[label].append(index)
    #
    # # print(class2index.items())
    #
    #
    # # assert
    # # for index , label in enumerate(all_label):
    # # print(label,unlabel_target_list[index])
    #
    total=0
    acc=0
    line2remove=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            total_add[int(line.split(" ")[1])] += 1
            # if total_add[int(line.split(" ")[1])] >= max_num_per_class:
            #     break
            # print(line.split(" ")[1],str(psudo_label))
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    # NOTE(review): raises ZeroDivisionError when nothing passes the threshold.
    print("acc of psudo label",acc*1.0/total)
    print("added number for each class ",sum(total_add),total_add)
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    # print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    netF.train()
    netC.train()
    return target_loader, target_loader_unl
def Entropy(input_):
    """Return the per-sample Shannon entropy of a batch of distributions.

    Args:
        input_: tensor of shape (batch, num_classes) holding probabilities
            (e.g. softmax outputs); rows are treated as distributions.

    Returns:
        1-D tensor of shape (batch,) with the entropy of each row.
    """
    # epsilon guards log(0) for zero-probability entries.
    epsilon = 1e-5
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)
def return_dataset_test(args):
    """Build the test-time loader over the unlabeled target split.

    Returns (target_loader_unl, class_list); the loader iterates in file
    order (shuffle=False) so predictions can be matched back to lines.
    """
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    image_set_file_test = os.path.join(
        base_path,
        'unlabeled_target_images_' + args.target + '_%d.txt' % (args.num))
    # AlexNet takes 227x227 crops; every other backbone uses 224x224.
    crop_size = 227 if args.net == 'alexnet' else 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    bs = 32 if args.net == 'alexnet' else 24
    target_loader_unl = torch.utils.data.DataLoader(
        target_dataset_unl, batch_size=bs * 2, num_workers=3,
        shuffle=False, drop_last=False)
    return target_loader_unl, class_list
import numpy as np
from scipy.io import loadmat
base_dir = './data'
def load_mnist(scale=True, usps=False, all_use=False):
    """Load MNIST splits from ``mnist_data.mat``.

    Returns (train_images, train_labels, test_images, test_labels) as
    float32 arrays in NCHW layout; the training split is shuffled and both
    splits are truncated to at most 25000 samples.  ``usps`` and
    ``all_use`` are accepted for loader-API symmetry and are unused here.
    """
    data = loadmat(base_dir + '/mnist_data.mat')
    if scale:
        # 32x32 variant: grayscale images replicated across 3 channels.
        train_imgs = np.reshape(data['train_32'], (55000, 32, 32, 1))
        test_imgs = np.reshape(data['test_32'], (10000, 32, 32, 1))
        train_imgs = np.concatenate([train_imgs, train_imgs, train_imgs], 3)
        test_imgs = np.concatenate([test_imgs, test_imgs, test_imgs], 3)
        train_imgs = train_imgs.transpose(0, 3, 1, 2).astype(np.float32)
        test_imgs = test_imgs.transpose(0, 3, 1, 2).astype(np.float32)
    else:
        # 28x28 variant stays single-channel (NHWC -> NCHW).
        train_imgs = data['train_28'].astype(np.float32).transpose((0, 3, 1, 2))
        test_imgs = data['test_28'].astype(np.float32).transpose((0, 3, 1, 2))
    train_label = np.argmax(data['label_train'], axis=1)
    # Shuffle the training split once, keeping labels aligned.
    order = np.random.permutation(train_imgs.shape[0])
    train_imgs = train_imgs[order]
    train_label = train_label[order]
    test_label = np.argmax(data['label_test'], axis=1)
    return (train_imgs[:25000], train_label[:25000],
            test_imgs[:25000], test_label[:25000])
import numpy as np
from scipy.io import loadmat
base_dir = './data'
def load_mnistm(scale=True, usps=False, all_use=False):
    """Load MNIST-M splits from ``mnistm_with_label.mat``.

    Returns (train_images, train_labels, test_images, test_labels) as
    float32 NCHW arrays; the training split is shuffled and truncated to
    25000 samples, the test split to 9000.  ``scale``, ``usps`` and
    ``all_use`` are accepted for loader-API symmetry and are unused here.
    """
    data = loadmat(base_dir + '/mnistm_with_label.mat')
    train_imgs = data['train'].transpose(0, 3, 1, 2).astype(np.float32)
    test_imgs = data['test'].transpose(0, 3, 1, 2).astype(np.float32)
    train_label = np.argmax(data['label_train'], axis=1)
    # Shuffle the training split once, keeping labels aligned.
    order = np.random.permutation(train_imgs.shape[0])
    train_imgs = train_imgs[order]
    train_label = train_label[order]
    test_label = np.argmax(data['label_test'], axis=1)
    return (train_imgs[:25000], train_label[:25000],
            test_imgs[:9000], test_label[:9000])
"""Data loading facilities for Omniglot experiment."""
import random
import os
from os.path import join
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset
from torchvision.datasets.utils import list_dir, list_files
from torchvision import transforms
from PIL import Image
# Number of images per category held out for validation inside each
# office_caltech dataset (0 = keep every image in the training split).
HOLD_OUT = 0
# Default preprocessing: resize to 100x100, convert to tensor, normalize
# with ImageNet channel statistics.
DEFAULT_TRANSFORM =transforms.Compose([transforms.Resize((100,100)),
                              transforms.ToTensor(),
                              transforms.Normalize(
                                  mean=(0.485, 0.456, 0.406),
                                  std=(0.229, 0.224, 0.225)
                              )])
# Fixed category-name -> integer-label mapping for the 10 shared
# office_caltech_10 classes.
CAT2LABEL={'bike': 0, 'monitor': 1, 'laptop_computer': 2, 'mug': 3, 'calculator': 4, 'projector': 5, 'keyboard': 6, 'headphones': 7, 'back_pack': 8, 'mouse': 9}
###############################################################################
class DataContainer(object):
    """Container building meta-train / meta-val / meta-test task loaders.

    NOTE(review): the original docstring described Omniglot, but the code
    below builds office_caltech_10 domain tasks.

    Arguments:
        root (str): root of dataset.
        pretrain_domains (list): domain names used for meta-training.
        target_domain (list): domain name(s) used for meta-validation and
            meta-test.
        num_classes (int): number of classes to enforce per task (optional).
        transform (func): transformation to apply to each input sample.
        seed (int): seed fed to ``random.seed`` before task construction.
        **kwargs (dict): keyword arguments to pass to the
            torch.utils.data.DataLoader
    """
    folder = "./office_caltech_10"
    target_folder = "./office_caltech_10"
    def __init__(self, root="./data", pretrain_domains=[],target_domain=[],
                 num_classes=10, transform=DEFAULT_TRANSFORM,
                 seed=1, **kwargs):
        # NOTE(review): mutable default arguments ([]) are shared across
        # calls; safe only while callers never mutate them in place.
        self.root = root
        self.pretrain_domains = pretrain_domains
        self.target_domain=target_domain
        self.transform = transform
        self.seed = seed
        self.kwargs = kwargs
        path = join(os.path.expanduser(self.root), self.target_folder)
        domains = list_dir(path)
        # Keep only domains that contain at least `num_classes` categories.
        if num_classes:
            domains = [a for a in domains
                       if len(list_dir(join(path, a))) >= num_classes]
        # assert self.num_pretrain_domains + TEST < len(domains), \
        #     'cannot create test set'
        random.seed(self.seed)
        train = self.pretrain_domains
        test = self.target_domain
        val = self.target_domain
        # One single-domain office_caltech dataset per requested domain;
        # HOLD_OUT controls the per-category validation split inside each.
        trainset = [office_caltech(root, [t], num_classes, HOLD_OUT,
                                   transform=transform) for t in train]
        testset = [office_caltech(root, [v], num_classes, HOLD_OUT,
                                  transform=transform) for v in test]
        valset = [office_caltech(root, [v], num_classes, HOLD_OUT,
                                 transform=transform) for v in val]
        self.domains = domains
        self.domains_train = train
        self.domains_test = test
        self.domains_val = val
        self.data_train = trainset
        self.data_test = testset
        self.data_val = valset
    def get_loader(self, task, batch_size, iterations):
        """Returns a DataLoader for given configuration.
        Arguments:
            task (office_caltech): dataset instance to wrap.
            batch_size (int): batch size in data loader.
            iterations (int): number of batches drawn by the RandomSampler.
        """
        return DataLoader(task,
                          batch_size,
                          sampler=RandomSampler(task, iterations, batch_size),
                          **self.kwargs)
    def get_test_loader(self, task, batch_size):
        """Returns a sequential (no custom sampler) DataLoader.
        Arguments:
            task (office_caltech): dataset instance to wrap.
            batch_size (int): batch size in data loader.
        """
        return DataLoader(task,
                          batch_size,
                          # sampler=RandomSampler(task, iterations, batch_size),
                          **self.kwargs)
    def train(self, meta_batch_size, batch_size, iterations, return_idx=False):
        """Generator meta-train batch
        Arguments:
            meta_batch_size (int): number of tasks in batch.
            batch_size (int): number of samples in each batch in the inner
                (task) loop.
            iterations (int): number of training steps on each task.
            return_idx (int): return task ids (default=False).
        """
        n_tasks = len(self.data_train)
        # With a single training task, replicate it meta_batch_size times.
        if n_tasks == 1:
            tasks = zip([0] * meta_batch_size,
                        self.data_train * meta_batch_size)
        else:
            # Otherwise draw whole shuffled rounds of tasks until at least
            # meta_batch_size tasks are collected, then truncate.
            tasks = []
            task_ids = list(range(n_tasks))
            while True:
                random.shuffle(task_ids)
                tasks.extend([(i, self.data_train[i]) for i in task_ids])
                if len(tasks) >= meta_batch_size:
                    break
            tasks = tasks[:meta_batch_size]
        task_ids, task_data = zip(*tasks)
        task_data = [self.get_loader(t, batch_size, iterations)
                     for t in task_data]
        if return_idx:
            return list(zip(task_ids, task_data))
        return task_data
    def val(self, batch_size, iterations, return_idx=False):
        """Generator meta-validation batch
        Arguments:
            batch_size (int): number of samples in each batch in the inner
                (task) loop.
            iterations (int): number of training steps on each task.
            return_idx (int): return task ids (default=False).
        """
        n = len(self.data_train)
        # Validation task ids continue after the training task ids.
        tsk = [i+n for i in range(len(self.data_val))]
        tasks = [self.get_loader(d, batch_size, iterations)
                 for d in self.data_val]
        if return_idx:
            return list(zip(tsk, tasks))
        return tasks
    def test(self, batch_size):
        """Return a shuffled DataLoader over the first meta-test task.
        Arguments:
            batch_size (int): number of samples per batch.
        """
        # n = len(self.data_train) + len(self.data_val)
        # tsk = [i+n for i in range(len(self.data_test))]
        # tasks = [self.get_loader(d, batch_size, 5)
        #          for d in self.data_test]
        # Switch the test dataset to its training split before loading.
        self.data_test[0].train()
        tasks=DataLoader(self.data_test[0], batch_size=batch_size, shuffle=True,
                         num_workers=4, drop_last=True)
        # if return_idx:
        #     return list(zip(tsk, tasks))
        return tasks
    def test_warp(self, batch_size):
        """Return RandomSampler-backed loaders (5 iterations each) for every
        meta-test task.
        Arguments:
            batch_size (int): number of samples per batch.
        """
        # n = len(self.data_train) + len(self.data_val)
        # tsk = [i+n for i in range(len(self.data_test))]
        tasks = [self.get_loader(d, batch_size, 5)
                 for d in self.data_test]
        # self.data_test[0].train()
        # tasks=DataLoader(self.data_test[0], batch_size=batch_size, shuffle=True,
        #                  num_workers=4, drop_last=True)
        # if return_idx:
        #     return list(zip(tsk, tasks))
        return tasks
class office_caltech(Dataset):
    """Dataset over one or more office_caltech_10 domains.

    NOTE(review): the original docstring described Omniglot; the code
    below indexes office_caltech_10 domain/category folders.

    Arguments:
        root (str): root directory containing ``office_caltech_10``.
        domains (list): domain folder names to include.
        num_classes (int): number of categories to keep (optional).
        hold_out (int): number of images per category held for the
            validation split (optional).
        transform (func): transformation to apply to each input sample.
        seed (int): seed used to shuffle categories before truncating to
            num_classes.
    """
    folder = "office_caltech_10"
    target_folder = "office_caltech_10"
    def __init__(self, root, domains, num_classes=None, hold_out=None,
                 transform=None, seed=None):
        self.root = root
        self.domains = domains
        self.num_classes = num_classes
        self.hold_out = hold_out #rate for validation
        self.transform = transform
        self.target_transform = None
        self.seed = seed
        self.target_folder = join(self.root, self.target_folder)
        self._domains = [a for a in list_dir(self.target_folder)
                         if a in self.domains]
        # Category paths are "<domain>/<category>" for each kept domain.
        self._categories= sum(
            [[join(a, c) for c in list_dir(join(self.target_folder, a))]
             for a in self._domains], [])
        # print(dict([(s.split("/")[1],index) for index,s in enumerate(self._categories)]))
        if seed:
            random.seed(seed)
            random.shuffle(self._categories)
        if self.num_classes:
            self._categories = self._categories[:num_classes]
        # self._categories_images = [
        #     [(image, idx) for image in
        #      list_files(join(self.target_folder, category), '.jpg')]
        #     for idx, category in enumerate(self._categories)
        # ]
        # The first `hold_out` images of each category go to the validation
        # split; all remaining images go to the training split.
        self._train_category_images = []
        self._val_category_images = []
        for idx, category in enumerate(self._categories):
            train_characters = []
            val_characters = []
            path_category=join(self.target_folder, category)
            for img_count, image in enumerate(
                    list_files(path_category, '.jpg')):
                image_path=join(path_category,image)
                # print(image_path)
                # NOTE(review): category.split("/") assumes '/' separators;
                # confirm behavior on Windows paths.
                if hold_out and img_count < hold_out:
                    val_characters.append((image_path, CAT2LABEL[category.split("/")[1]]))
                else:
                    train_characters.append((image_path, CAT2LABEL[category.split("/")[1]]))
            self._train_category_images.append(train_characters)
            self._val_category_images.append(val_characters)
        #test
        # print(self._train_category_images)
        # for category in self._train_category_images:
        #     for (image,idx) in category:
        #         if "amazon" in image:
        #             print("image",image,idx)
        # Flatten per-category lists into single (path, label) lists.
        self._flat_train_character_images = sum(
            self._train_category_images, [])
        self._flat_val_character_images = sum(
            self._val_category_images, [])
        self._train = True
        self._set_images()
        # self.image_count=set()
        # self.count=0
        # print("creat")
    def __getitem__(self, index):
        """Return (transformed image, int label) for the active split."""
        path_img, label = self._flat_character_images[index]
        img = Image.open(path_img).convert('RGB')  # 0~255
        # print("img:",path_img,"label",label)
        # self.image_count.add(path_img)
        # self.count+=1
        # if index>800:
        #     print(index)
        # print(self.count)
        # if self.count%30==0:
        #     print("domainS",self.domains,"len",len(self.image_count))
        if self.transform is not None:
            img = self.transform(img)  # apply the transform here (to tensor, etc.)
        return img, label
    def __len__(self):
        return len(self._flat_character_images)
    def train(self):
        """Train mode"""
        self._train = True
        self._set_images()
    def eval(self):
        """Eval mode"""
        self._train = False
        self._set_images()
    def _set_images(self):
        """Point the active image list at the train or validation split."""
        if self._train:
            self._flat_character_images = self._flat_train_character_images
        else:
            self._flat_character_images = self._flat_val_character_images
class RandomSampler(Sampler):
    r"""Sampler yielding a fixed budget of random indices per epoch.

    In train mode it yields ``iterations * batch_size`` distinct indices
    drawn from a fresh permutation (asserting the budget fits the dataset);
    in eval mode it yields one full permutation of the dataset.

    Arguments:
        data_source (Dataset): dataset to sample from
        iterations (int): number of batches per epoch
        batch_size (int): number of samples in each batch
    """
    def __init__(self, data_source, iterations, batch_size):
        self.data_source = data_source
        self.iterations = iterations
        self.batch_size = batch_size
    def __iter__(self):
        source_len = len(self.data_source)
        if self.data_source._train:
            budget = self.iterations * self.batch_size
            # The per-epoch budget must fit inside the dataset.
            assert budget < source_len
            permutation = torch.randperm(source_len)
            indices = permutation[0 : budget % source_len]
        else:
            indices = torch.randperm(source_len)
        return iter(indices.tolist())
    def __len__(self):  # pylint: disable=protected-access
        if self.data_source._train:
            return self.iterations * self.batch_size
        return len(self.data_source)
"""Data loading facilities for Omniglot experiment."""
import random
import os
from os.path import join
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset
from torchvision.datasets.utils import list_dir, list_files
from torchvision import transforms
from PIL import Image
# Default preprocessing (duplicate of the definition earlier in the file):
# resize to 100x100, convert to tensor, normalize with ImageNet statistics.
DEFAULT_TRANSFORM =transforms.Compose([transforms.Resize((100,100)),
                              transforms.ToTensor(),
                              transforms.Normalize(
                                  mean=(0.485, 0.456, 0.406),
                                  std=(0.229, 0.224, 0.225)
                              )])
# Fixed category-name -> integer-label mapping for the 10 shared
# office_caltech_10 classes.
CAT2LABEL={'bike': 0, 'monitor': 1, 'laptop_computer': 2, 'mug': 3, 'calculator': 4, 'projector': 5, 'keyboard': 6, 'headphones': 7, 'back_pack': 8, 'mouse': 9}
###############################################################################
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
def make_dataset(image_list, labels):
    """Pair image paths with labels.

    Args:
        image_list: lines of ``"path label"`` (or ``"path l1 l2 ..."``
            for multi-label entries).
        labels: optional 2-D array of per-row labels; when given, row i of
            ``labels`` is paired with the stripped i-th line of image_list.

    Returns:
        List of (path, label) tuples; label is an int for single-label
        lines, or an array for multi-label lines / explicit ``labels``.
    """
    # BUG FIX: the original tested `if labels:`, which raises
    # "truth value of an array ... is ambiguous" for multi-element numpy
    # arrays.  Test for None explicitly instead.
    if labels is not None:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            # Multi-label line: everything after the path is a label vector.
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images
def rgb_loader(path):
    """Read the file at *path* and return it as an RGB PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as img:
        return img.convert('RGB')
def l_loader(path):
    """Read the file at *path* and return it as a grayscale ('L') PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as img:
        return img.convert('L')
class ImageList(Dataset):
    """Dataset over an explicit list of ``"path label"`` lines.

    Args:
        image_list: list of text lines, each ``"path label"`` (or
            multi-label ``"path l1 l2 ..."``).
        labels: optional 2-D label array overriding the in-line labels.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the label.
        mode: 'RGB' or 'L' — the color mode images are converted to.
    """
    def __init__(self, image_list, labels=None, transform=None,target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        else:
            # Fail fast: the original silently left self.loader unset for an
            # unknown mode, surfacing later as an opaque AttributeError.
            raise ValueError("mode must be 'RGB' or 'L', got %r" % (mode,))
        # When set to 1, __getitem__ additionally returns the sample index.
        self.return_index=0
    def __getitem__(self, index):
        """Return (image, target) — plus the index when return_index == 1."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.return_index==1:
            return img, target,index
        else:
            return img, target
    def __len__(self):
        return len(self.imgs)
import os
import torch
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.data_list import ImageList_idx,ImageList,PairBatchSampler
class ResizeImage():
    """Callable transform that resizes a PIL image to a fixed size.

    A bare int is expanded to a square (size, size) target; any other
    sequence is stored as-is and unpacked when the transform is applied.
    """
    def __init__(self, size):
        # Expand a bare int into a square target; keep sequences as given.
        self.size = (int(size), int(size)) if isinstance(size, int) else size
    def __call__(self, img):
        first, second = self.size
        return img.resize((first, second))
def return_dataset(args):
    """Build source/target datasets and loaders for SSDA training.

    Splits the labeled source list into train/val, builds labeled target,
    unlabeled target and test datasets, and wraps each in a DataLoader.
    When ``args.skd_src == 1`` the source and unlabeled-target loaders use
    a PairBatchSampler (paired batches for the self-KD term).

    Returns:
        (source_loader, source_val_loader, target_loader,
         target_loader_unl, target_loader_val, target_loader_test,
         target_dataset_unl)
    """
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): `args.dataset in "office-home"` is a substring test, so
    # e.g. "office" or "home" would also match — confirm this is intended
    # rather than `args.dataset == "office-home"`.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else :
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.s + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; other backbones use 224x224.
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    def split(train_r, source_path):
        # Randomly split the lines of source_path into train/val with
        # train fraction `train_r`.
        with open(source_path, 'r') as f:
            data = f.readlines()
        train_len = int(len(data) * train_r)
        train, val = torch.utils.data.random_split(data, [train_len, len(data) - train_len])
        return train, val
    # NOTE(review): another substring test (`in "multi"`); see above.
    if args.dataset in "multi":
        source_train, source_val = split(train_r=0.95, source_path=image_set_file_s)
    else:
        source_train, source_val = split(train_r=0.90, source_path=image_set_file_s)
    print("source_train and val num", len(source_train), len(source_val))
    source_dataset = ImageList(source_train, root=root,
                               transform=data_transforms['train'])
    source_val_dataset = ImageList(source_val, root=root,
                                   transform=data_transforms['val'])
    target_dataset = ImageList(open(image_set_file_t).readlines(), root=root,
                               transform=data_transforms['val'])
    target_dataset_val = ImageList(open(image_set_file_t_val).readlines(), root=root,
                                   transform=data_transforms['val'])
    target_dataset_unl = ImageList_idx(open(image_set_file_unl).readlines(), root=root,
                                       transform=data_transforms['val'])
    target_dataset_test = ImageList(open(image_set_file_unl).readlines(), root=root,
                                    transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    # if args.net == 'alexnet':
    #     bs = 20
    # else:
    #     bs = 16
    bs=args.batch_size*2 # for KD term
    if args.skd_src==1:
        source_loader = torch.utils.data.DataLoader(source_dataset,batch_sampler=PairBatchSampler(source_dataset, args.batch_size),num_workers=args.worker)
    else:
        source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=True,
                                                    drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=False,
                                                    drop_last=False)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=False)
    if args.skd_src==1:
        target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                        batch_sampler=PairBatchSampler(target_dataset_unl, args.batch_size),num_workers=args.worker)
    else:
        target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                        batch_size=bs * 1, num_workers=args.worker,
                                                        shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
           target_loader_val, target_loader_test, target_dataset_unl
def return_dataset_test(args):
    """Build a test-time DataLoader over the unlabeled target image list.

    Returns (target_loader_unl, class_list); the loader iterates in file
    order (shuffle=False) so predictions can be matched back to lines.
    """
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    image_set_file_test = os.path.join(base_path,
                                       'unlabeled_target_images_' +
                                       args.target + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; other backbones use 224x224.
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=False, drop_last=False)
    return target_loader_unl, class_list
import os
import torch
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.data_list import ImageList_idx,ImageList
import collections
import torch
import sys
project_root="."
class ResizeImage():
    """Callable transform that resizes a PIL image to a fixed size.

    A bare int is expanded to a square (size, size) target; any other
    sequence is stored as-is and unpacked when the transform is applied.
    """
    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size
    def __call__(self, img):
        th, tw = self.size
        return img.resize((th, tw))
class TransformTwice:
    """Apply a (typically stochastic) transform twice to the same input.

    Calling the instance returns a pair of independently transformed
    views of the input — used to get two augmentations of one image.
    """
    def __init__(self, transform):
        self.transform = transform
    def __call__(self, inp):
        first_view = self.transform(inp)
        second_view = self.transform(inp)
        return first_view, second_view
def apply_train_dropout(m):
    """Switch Dropout modules (only) back to train mode for MC dropout.

    Intended for ``nn.Module.apply`` after ``model.eval()``: dropout stays
    stochastic at inference time while other layers remain in eval mode.
    """
    # isinstance (idiomatic) also covers subclasses of nn.Dropout, unlike
    # the exact type() comparison it replaces.
    if isinstance(m, torch.nn.Dropout):
        m.train()
def return_dataloader_by_UPS(args,netF,netC,unlabeled_data_loader):
    """Pseudo-label the unlabeled target split and rebuild its loaders.

    Pipeline:
      1. Run netF/netC in eval mode but with Dropout re-enabled (MC
         dropout), averaging 5 repeated predictions over two augmented
         views of every unlabeled image.
      2. For each predicted class, pick the ``args.max_num_per_class``
         samples with the lowest predictive entropy.
      3. Move the selected lines into the labeled target list with their
         predicted (pseudo) labels and rebuild both DataLoaders.

    Requires CUDA (inputs are moved with .cuda()).  The
    ``unlabeled_data_loader`` argument is not used by this implementation.

    Returns:
        (target_loader, target_loader_unl, pseudo_label_accuracy)
    """
    netF.eval()
    netC.eval()
    base_path = project_root+"/data/SSDA_split/%s" % args.dataset
    # NOTE(review): substring test — any args.dataset that is a substring
    # of "office-home" matches; confirm this is intended.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = project_root+"/data/OfficeHomeDataset/"
    else:
        root = project_root+"/data/%s/" % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    # if args.uda==1:
    #     image_set_file_unl = \
    #         os.path.join(base_path,
    #                      'labeled_source_images_' +
    #                      args.t+".txt")
    # else:
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; other backbones use 224x224.
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    netF.apply(apply_train_dropout)# MC dropout
    netC.apply(apply_train_dropout)
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    # Two random augmentations per image so the MC-dropout average also
    # covers input augmentation noise.
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_dataset_unl.return_index=True
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs * 1, num_workers=args.worker,
                                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            # print(step,index)
            # = data[0]
            # labels = data[1]
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            # labels = labels.cuda() # 2020 07 06
            # inputs = inputs
            output=[]
            batch_s=inputs2.shape[0]
            repeat=5
            # 5 stochastic forward passes over each of the two views ->
            # 10 predictions per sample, averaged below.
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            # print("std",torch.std(output,dim=0))
            outputs=torch.mean(output,dim=0)
            # print("outputs",outputs.shape)
            # print("mean",outputs)
            # outputs,margin_logits = netC(netF(inputs),labels)
            # labels = labels.cpu() # 2020 07 06
            # step=torch.tensor(step).int().reshape(-1)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
            # print(type(index))
            # all_index.append(index)
    # print("len",len(all_index),len(all_output))
    _, pred_label = torch.max(all_output, 1)
    # _, predict = torch.max(all_output, 1)
    # res = torch.squeeze(predict).float() == all_label
    # accuracy = torch.sum(res).item() / float(all_label.size()[0])
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # mean_ent = torch.mean(entropy).cpu().data.item()
    # entropy for each class
    class_to_mean_entropy = list(0. for i in range(args.class_num))
    # correct = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    # print(all_index.shape,pred_label.shape)
    # Select, per predicted class, the lowest-entropy samples (at most
    # max_num_per_class of them).
    for label in range(args.class_num):
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # classforentropy[label]=list(entropy[pred_label == label].numpy())
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        # print("index2all_label",index2all_index[label])
        classentropy=entropy[pred_label == label]
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        # print("entroy_len",len(entropy[pred_label == label]),"index_len",len(all_index[pred_label == label]))
        # print("indexes",indexes)
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    # for index in all_index:
    #     label=pred_label[index].item()
    #     if entropy[index] < class_to_mean_entropy[label]*0.5:
    #         class2index[label].append(index)
    #
    # # print(class2index.items())
    #
    #
    # # assert
    # # for index , label in enumerate(all_label):
    # #     print(label,unlabel_target_list[index])
    #
    total=0
    acc=0
    line2remove=[]
    # In the UDA setting there is no labeled target list to extend; start
    # from an empty one.
    if args.uda == 1:
        target_list=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            total_add[int(line.split(" ")[1])] += 1
            # if total_add[int(line.split(" ")[1])] >= max_num_per_class:
            #     break
            # print(line.split(" ")[1],str(psudo_label))
            # Ground-truth labels are only used here to report accuracy.
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    # print("acc of psudo label",acc*1.0/total)
    # print("added number for each class ",sum(total_add),total_add)
    # NOTE(review): if no sample is selected, total stays 0 and the final
    # acc*1.0/total raises ZeroDivisionError — confirm this cannot happen.
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
        # print("remove",line)
    print("removed data from unlalbed dataset",before_remove," to ",len(unlabel_target_list))
    # print(target_list)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    # print("target len", len(unlabel_target_list))
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl.return_index=True
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    # batch_size=bs,
                                    # pin_memory=True,
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    # pin_memory=True,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    # Restore full training mode before handing the loaders back.
    netF.train()
    netC.train()
    return target_loader, target_loader_unl,acc*1.0/total
def return_dataset(args):
    """Build source/target datasets and loaders for SSDA training.

    Splits the labeled source list into train/val, builds labeled target,
    unlabeled target (with sample indices) and test datasets, and wraps
    each in a DataLoader.

    Returns:
        (source_loader, source_val_loader, target_loader,
         target_loader_unl, target_loader_val, target_loader_test,
         class_list)
    """
    base_path = project_root + "/data/SSDA_split/%s" % args.dataset
    # NOTE(review): `args.dataset in "office-home"` is a substring test, so
    # e.g. "office" or "home" would also match — confirm this is intended
    # rather than `args.dataset == "office-home"`.
    if args.dataset in "office-home":
        # args.dataset='OfficeHomeDataset'
        root = project_root + "/data/OfficeHomeDataset/"
    else:
        root = project_root + "/data/%s/" % args.dataset
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.s + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; other backbones use 224x224.
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    def split(train_r, source_path):
        # Randomly split the lines of source_path into train/val with
        # train fraction `train_r`.
        with open(source_path, 'r') as f:
            data = f.readlines()
        train_len = int(len(data) * train_r)
        train, val = torch.utils.data.random_split(data, [train_len, len(data) - train_len])
        return train, val
    # NOTE(review): another substring test (`in "multi"`); see above.
    if args.dataset in "multi":
        source_train, source_val = split(train_r=0.95, source_path=image_set_file_s)
    else:
        source_train, source_val = split(train_r=0.90, source_path=image_set_file_s)
    print("source_train and val num", len(source_train), len(source_val))
    source_dataset = Imagelists_VISDA(source_train, root=root,
                                      transform=data_transforms['train'])
    source_val_dataset = Imagelists_VISDA(source_val, root=root,
                                          transform=data_transforms['val'])
    target_dataset = Imagelists_VISDA(open(image_set_file_t).readlines(), root=root,
                                      transform=data_transforms['val'])
    target_dataset_val = Imagelists_VISDA(open(image_set_file_t_val).readlines(), root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(open(image_set_file_unl).readlines(), root=root,
                                          transform=data_transforms['val'])
    # Unlabeled samples also yield their index (needed for pseudo-labeling).
    target_dataset_unl.return_index = True
    target_dataset_test = Imagelists_VISDA(open(image_set_file_unl).readlines(), root=root,
                                           transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    # if args.net == 'alexnet':
    #     bs = 20
    # else:
    #     bs = 16
    bs=args.batch_size
    source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs,
                                                num_workers=args.worker, shuffle=True,
                                                drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=False,
                                                    drop_last=False)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=False)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
           target_loader_val, target_loader_test, class_list
def return_dataset_test(args):
    """Build the unlabeled-target test loader plus the source class list.

    Uses the deterministic test transform (resize -> center crop ->
    normalize) and returns ``(target_loader_unl, class_list)``.
    """
    txt_dir = './data/txt/%s' % args.dataset
    data_root = './data/%s/' % args.dataset
    source_list_file = os.path.join(txt_dir, args.source + '_all' + '.txt')
    test_list_file = os.path.join(
        txt_dir,
        'unlabeled_target_images_' + args.target + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 crops; the other backbones use 224x224.
    crop_size = 227 if args.net == 'alexnet' else 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(test_list_file, root=data_root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(source_list_file)
    print("%d classes in this dataset" % len(class_list))
    bs = 32 if args.net == 'alexnet' else 24
    target_loader_unl = torch.utils.data.DataLoader(
        target_dataset_unl, batch_size=bs * 2, num_workers=3,
        shuffle=False, drop_last=False)
    return target_loader_unl, class_list
def Entropy(input_):
    """Per-sample Shannon entropy of a batch of probability vectors.

    Args:
        input_: tensor of shape (batch, num_classes) holding (softmax)
            probabilities; rows are assumed to sum to 1.

    Returns:
        1-D tensor of shape (batch,) with the entropy of each row.
    """
    # (removed an unused local `bs = input_.size(0)` present before)
    epsilon = 1e-5  # guards log(0) for zero-probability entries
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)
\ No newline at end of file
import os
import torch
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.text_data_list import Text
import collections
import torch
import sys
import random
import numpy as np
import pickle
project_root="."
class ResizeImage():
    """Callable transform that resizes an image to a fixed size.

    An int argument produces a square ``(size, size)`` target; a sequence
    is used as given.
    """

    def __init__(self, size):
        # Normalize an int into a square (size, size) pair.
        self.size = (int(size), int(size)) if isinstance(size, int) else size

    def __call__(self, img):
        # NOTE(review): PIL's Image.resize takes (width, height); for the
        # square sizes used in this file the order is irrelevant.
        first, second = self.size
        return img.resize((first, second))
class TransformTwice:
    """Wrap a transform so each call yields two augmentations of one input.

    With a stochastic transform the two returned views differ — useful for
    consistency-based methods (e.g. MC-dropout pseudo-labeling below).
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # Two independent invocations of the wrapped transform.
        return self.transform(inp), self.transform(inp)
def apply_train_dropout(m):
    """Force dropout layers back into train mode (for MC-dropout sampling).

    Intended for ``module.apply(apply_train_dropout)`` after
    ``module.eval()`` so dropout stays stochastic while batch-norm etc.
    remain in eval mode.
    """
    # isinstance (rather than the previous exact `type(m) ==` check) also
    # catches torch.nn.Dropout subclasses.
    if isinstance(m, torch.nn.Dropout):
        m.train()
def return_dataloader_by_UPS(args,netF,netC,unlabeled_data_loader):
    """Mine confident pseudo-labels from the unlabeled split via MC dropout.

    Puts netF/netC in eval mode but re-enables their dropout layers, scores
    every unlabeled image ``repeat`` times on two augmented views, averages
    the logits, and for each predicted class keeps the
    ``args.max_num_per_class`` lowest-entropy samples as pseudo-labeled
    target data.  Selected lines move from the unlabeled list into the
    labeled target list, and fresh DataLoaders are built over both.

    Returns:
        (target_loader, target_loader_unl, pseudo-label accuracy).

    NOTE(review): the ``unlabeled_data_loader`` argument is unused — the
    unlabeled file list is re-read from disk instead; confirm intended.
    NOTE(review): if no sample is selected, ``acc*1.0/total`` divides by
    zero.
    """
    netF.eval()
    netC.eval()
    base_path = project_root+"/data/SSDA_split/%s" % args.dataset
    # `in` does a substring test here — presumably always called with the
    # exact dataset name; "office" would also match "office-home".
    if args.dataset in "office-home":
        root = project_root+"/data/OfficeHomeDataset/"
    else:
        root = project_root+"/data/%s/" % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    # In the UDA setting (args.uda == 1) the "unlabeled" pool is the
    # labeled source list; its ground-truth labels are only used for the
    # accuracy statistic below.
    if args.uda==1:
        image_set_file_unl = \
            os.path.join(base_path,
                         'labeled_source_images_' +
                         args.t+".txt")
    else:
        image_set_file_unl = \
            os.path.join(base_path,
                         'unlabeled_target_images_' +
                         args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # 'val' keeps random augmentation: TransformTwice below needs two
        # *different* stochastic views of each image.
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    netF.apply(apply_train_dropout)# MC dropout
    netC.apply(apply_train_dropout)
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_dataset_unl.return_index=True
    # shuffle=False so batch order is deterministic while scoring.
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            output=[]
            batch_s=inputs2.shape[0]
            # 2 views x `repeat` stochastic forward passes per image.
            repeat=5
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            # Average the 2*repeat MC samples into one logit vector per image.
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            outputs=torch.mean(output,dim=0)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
    _, pred_label = torch.max(all_output, 1)
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # entropy for each class
    class_to_mean_entropy = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    for label in range(args.class_num):
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # Map each position within this class's subset back to the global
        # dataset index.
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        classentropy=entropy[pred_label == label]
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            # Keep the max_num_per_class most confident (lowest-entropy) samples.
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    total=0
    acc=0
    line2remove=[]
    # Under UDA there are no labeled target samples: start from scratch.
    if args.uda == 1:
        target_list=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            # Replace the ground-truth label field with the pseudo-label.
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            # Ground-truth label (2nd field) used only for statistics.
            total_add[int(line.split(" ")[1])] += 1
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl.return_index=True
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    # Restore full train mode before returning to the caller's training loop.
    netF.train()
    netC.train()
    return target_loader, target_loader_unl,acc*1.0/total
def return_dataset(args):
    """Build loaders for the semi-supervised text (Amazon reviews) setting.

    Loads pre-pickled feature/label splits, carves a 90/10 train/val split
    out of the labeled source data with a fixed seed, and wraps every split
    in a DataLoader.

    Returns:
        (source_loader, source_val_loader, target_loader, target_loader_unl,
         target_loader_val, target_loader_test, None) — the final ``None``
        stands in for the class list returned by the image variant.
    """
    base_path = project_root + "/data/SSDA_split/%s" % args.dataset
    file_root = os.getcwd()
    # Substring test: any dataset name contained in "amazon" matches.
    if args.dataset in "amazon":
        base_path=os.path.join(file_root,"data/amazon")
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_' +
                     args.s + '.pkl')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_' +
                     args.t + '_%d.pkl' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_' +
                     args.t + '.pkl')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_' +
                     args.t + '_%d.pkl' % (args.num))
    # Text features need no augmentation; all three transforms are identical.
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToTensor()
        ]),
        'val': transforms.Compose([
            transforms.ToTensor()
        ]),
        'test': transforms.Compose([
            transforms.ToTensor()
        ]),
    }
    def split(train_r,rate=0.9):
        # train_r is indexable as a pair: train_r[0] = features,
        # train_r[1] = labels.
        train_len=int(len(train_r[0])*rate)
        print(np.array(list(range(len(train_r[0])))).shape)
        # Fixed generator seed keeps the split reproducible across runs.
        # NOTE(review): random_split returns Subset objects that are used
        # below as fancy indices — this relies on train_r[0]/train_r[1]
        # supporting array-style indexing (e.g. numpy arrays); verify
        # against the pickled data format.
        train_index, val_index=torch.utils.data.random_split(range(len(train_r[0])), [train_len, len(train_r[0]) - train_len],generator=torch.Generator().manual_seed(42))
        train={}
        val={}
        train[0]=train_r[0][train_index]
        train[1]=train_r[1][train_index]
        val[0]=train_r[0][val_index]
        val[1]=train_r[1][val_index]
        return train, val
    # NOTE(review): both branches are identical; the "multi" special case
    # appears vestigial.
    if args.dataset in "multi":
        source_train, source_val = split(pickle.load(open(image_set_file_s,'rb')))
    else:
        source_train, source_val = split(pickle.load(open(image_set_file_s,'rb')))
    print("source_train and val num", len(source_train[1]), len(source_val[1]))
    source_dataset = Text(source_train,
                          transform=data_transforms['train'])
    source_val_dataset = Text(source_val,
                              transform=data_transforms['val'])
    target_dataset = Text(pickle.load(open(image_set_file_t,'rb')),
                          transform=data_transforms['val'])
    target_dataset_val = Text(pickle.load(open(image_set_file_t_val,'rb')),
                              transform=data_transforms['val'])
    target_dataset_unl = Text(pickle.load(open(image_set_file_unl,'rb')),
                              transform=data_transforms['val'])
    # The unlabeled dataset also yields its sample index (needed by the
    # pseudo-labeling code).
    target_dataset_unl.return_index = True
    target_dataset_test = Text(pickle.load(open(image_set_file_unl,'rb')),
                               transform=data_transforms['test'])
    bs=args.batch_size
    source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs,
                                                num_workers=args.worker, shuffle=True,
                                                drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=False,
                                                    drop_last=False)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=False)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
           target_loader_val, target_loader_test, None
def return_dataset_test(args):
    """Return the unlabeled-target test loader and the source class list.

    Applies the deterministic evaluation transform and returns
    ``(target_loader_unl, class_list)``.
    """
    txt_dir = './data/txt/%s' % args.dataset
    data_root = './data/%s/' % args.dataset
    source_list_file = os.path.join(txt_dir, args.source + '_all' + '.txt')
    test_list_file = os.path.join(
        txt_dir,
        'unlabeled_target_images_' + args.target + '_%d.txt' % (args.num))
    # 227 for AlexNet, 224 for every other backbone.
    crop_size = 227 if args.net == 'alexnet' else 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(test_list_file, root=data_root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(source_list_file)
    print("%d classes in this dataset" % len(class_list))
    bs = 32 if args.net == 'alexnet' else 24
    target_loader_unl = torch.utils.data.DataLoader(
        target_dataset_unl, batch_size=bs * 2, num_workers=3,
        shuffle=False, drop_last=False)
    return target_loader_unl, class_list
def Entropy(input_):
    """Row-wise Shannon entropy for a batch of probability vectors.

    Args:
        input_: tensor of shape (batch, num_classes) of probabilities.

    Returns:
        1-D tensor of shape (batch,) with the entropy of each row.
    """
    # (removed an unused local `bs = input_.size(0)` present before)
    epsilon = 1e-5  # guards log(0) for zero-probability entries
    entropy = -input_ * torch.log(input_ + epsilon)
    return torch.sum(entropy, dim=1)
\ No newline at end of file
import os
import torch
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.data_list import ImageList_idx,ImageList
import collections
import torch
import sys
import pickle
# project_root="."
import torch
import os
import random
import gensim # loads pretrained word2vec vectors
# # import jieba # Chinese word segmentation
from torch import nn
import numpy as np
from numpy import *
# from torch.utils.data import Dataset,DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter
from tqdm import tqdm
# from zhconv import convert # simplified/traditional Chinese conversion
import pickle
from torch.utils.data import Dataset,DataLoader
from torch.nn.utils.rnn import pad_sequence
class CommentDataSet(Dataset):
    """Dataset of tokenized reviews mapped to vocabulary indices.

    Each element of ``data_text`` is a ``(review, domain, y)`` triple where
    ``review`` is a list of sentences and each sentence a list of word
    tokens.  Reviews are flattened into one index list, truncated to at
    most ``max_sentence_size`` sentences and ``memory_size`` words per
    sentence; out-of-vocabulary words map to index 0.
    """

    def __init__(self, data_text, word2idx, idx2word, memory_size, max_sentence_size):
        self.data_text = data_text
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.memory_size, self.max_sentence_size = memory_size, max_sentence_size
        self.data, self.label = self.get_data_label()
        # When set to 1, __getitem__ additionally returns the sample index
        # (used by pseudo-labeling code that maps predictions back to rows).
        self.get_index = 0

    def __getitem__(self, idx: int):
        if self.get_index == 1:
            return self.data[idx], self.label[idx], idx
        return self.data[idx], self.label[idx]

    def __len__(self):
        return len(self.data)

    def get_data_label(self):
        """Convert raw token triples into (index-list, label-tensor) pairs."""
        data = []
        label = []
        for review, domain, y in self.data_text:
            words_to_idx = []
            for s_cnt, sentence in enumerate(review):
                if s_cnt == self.max_sentence_size:
                    break
                for w_cnt, word in enumerate(sentence):
                    if w_cnt == self.memory_size:
                        break
                    # dict.get replaces the previous broad
                    # `except BaseException` around a plain dict lookup;
                    # OOV words fall back to index 0 (padding/unknown slot).
                    words_to_idx.append(self.word2idx.get(word, 0))
            label.append(torch.tensor(y, dtype=torch.int64))
            data.append(words_to_idx)
        return data, label
def return_dataset(args):
    """Build all loaders for cross-domain sentiment classification on text.

    Loads tokenized Amazon-review pickles for source/target domains, builds
    (or loads a cached) word-embedding matrix, splits the target pool into
    few-shot labeled / validation / test parts, and wraps everything in
    DataLoaders that pad variable-length index sequences.

    Returns:
        (train_loader, validation_loader, target_loader, target_loader_unl,
         target_loader_val, test_loader, word_embedding)
    """
    source_domain, target_domain=args.source_domain, args.target_domain
    def getVocab(data):
        """Count how often every word occurs across all reviews."""
        vocab = {}
        for review, _, _, in data:
            for sentence in review:
                for word in sentence:
                    vocab[word] = vocab.get(word, 0) + 1
        return vocab
    def get_review(f, domain, label):
        # Read a pickled token list and tag every review with its domain id
        # and sentiment (1=positive, 0=negative; "unlabeled" keeps y=1).
        reviews = []
        y = 1  # sentiment label
        if label == "positive":
            y = 1
        elif label == "negative":
            y = 0
        with open(f, 'rb') as F:
            token_list = pickle.load(F)
        for tokens in token_list:
            reviews.append((tokens, domain, y))
        return reviews
    def load_bin_vec(fname, vocab):
        """
        Loads 300x1 word vecs from Google (Mikolov) word2vec
        """
        # NOTE(review): unused here — load_bin_vec_with_gensim below is
        # called instead.  This char-by-char parser compares bytes against
        # str (' ' / '\n') and so would not terminate correctly on
        # Python 3; np.fromstring is also deprecated.
        word_vecs = {}
        with open(fname, "rb") as f:
            header = f.readline()
            vocab_size, layer1_size = map(int, header.split())
            binary_len = np.dtype('float32').itemsize * layer1_size
            cnt = 0
            for line in range(vocab_size):
                cnt += 1
                print("line", cnt, line)
                word = []
                while True:
                    ch = f.read(1)
                    if ch == ' ':
                        word = ''.join(word)
                        break
                    if ch != '\n':
                        word.append(ch)
                if word in vocab:
                    word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32')
                else:
                    f.read(binary_len)
        return word_vecs
    def load_bin_vec_with_gensim(fname, vocab):
        """
        Loads 300x1 word vecs from Google (Mikolov) word2vec
        """
        # NOTE(review): KeyedVectors.vocab was removed in gensim 4.x; this
        # code assumes gensim 3.x.
        word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(fname, binary=True)
        word_vecs = {}
        for word, _ in word2vec_model.vocab.items():
            if word in vocab:
                word_vecs[word] = np.array(word2vec_model[word], dtype='float32')
        return word_vecs
    def add_unknown_words(word_vecs, vocab, min_df=1, dim=300):
        """
        For words that occur in at least min_df documents, create a separate word vector.
        0.25 is chosen so the unknown vectors have (approximately) same variance as pre-trained ones
        """
        for word in vocab:
            if word not in word_vecs and vocab[word] >= min_df:
                word_vecs[word] = np.random.uniform(-0.25, 0.25, dim)
    def get_w2vec(vocab, FLAGS):
        """
        Get word matrix. W[i] is the vector for word indexed by i
        """
        FLAGS.w2v_path = os.path.join(os.getcwd(), FLAGS.w2v_path)
        word_vecs = load_bin_vec_with_gensim(FLAGS.w2v_path, vocab)
        add_unknown_words(word_vecs, vocab)
        dim = len(list(word_vecs.values())[0])
        vocab_size = len(word_vecs)
        word_idx_map = dict()
        idx_word_map = dict()
        # Row 0 stays all-zero: it is the padding / unknown-word slot.
        W = np.zeros(shape=(vocab_size + 1, dim), dtype='float32')
        W[0] = np.zeros(dim, dtype='float32')
        i = 1
        for word in word_vecs:
            W[i] = word_vecs[word]
            word_idx_map[word] = i
            idx_word_map[i] = word
            i += 1
        return W, word_idx_map, idx_word_map
    def save_domain_word(file_path, emb, w2idx, idx2w):
        # Cache the embedding matrix and both index maps in one pickle.
        myword2vector = {}
        myword2vector["word_embedding"] = emb
        myword2vector["word2idx"] = w2idx
        myword2vector["idx2word"] = idx2w
        pickle.dump(myword2vector, open(file_path, "wb"))
        print("emb_saved")
    def load_domain_word(file_path):
        myword2vector = pickle.load(open(file_path, "rb"))
        print("emb_loaded")
        return myword2vector["word_embedding"], myword2vector["word2idx"], myword2vector["idx2word"]
    def mycollate_fn(data):
        # `data` is the list of per-sample tuples from __getitem__ — either
        # (input, label) or (input, label, idx) — batch_size of them.
        data.sort(key=lambda x: len(x[0]), reverse=True)  # sort by input length
        data_length = [len(sq[0]) for sq in data]
        input_data = []
        label_data = []
        idex=[]
        for i in data:
            input_data.append(torch.from_numpy(np.array(i[0])))
            label_data.append(i[1])
            # 3-tuples mean the dataset was asked to return sample indices.
            if len(data[0])==3:
                idex.append(i[2])
        input_data = pad_sequence(input_data, batch_first=True, padding_value=0)
        label_data = torch.tensor(label_data)
        if len(data[0]) == 3:
            return input_data, label_data,idex, data_length
        else :
            return input_data, label_data, data_length
    def split_target(shot,unlabeled_target):
        # Carve `shot` positive + `shot` negative labeled examples out of
        # the target pool; ~5% of the rest becomes validation, the
        # remainder test.
        labeled_target=[]
        test_data=[]
        val_target = []
        positive_cnt,negtive_cnt=0,0
        random.seed(0)  # fixed seed: the few-shot split is reproducible
        indexes=list(range(len(unlabeled_target)))
        random.shuffle(indexes)
        p_cnt,n_cnt=0,0
        for index in indexes:
            (review, domain, label)=unlabeled_target[index]
            if label==0 and negtive_cnt<shot:
                labeled_target.append((review,domain,label))
                negtive_cnt+=1
            elif label==1 and positive_cnt<shot:
                labeled_target.append((review,domain,label))
                positive_cnt+=1
            else:
                if random.random()<0.05:
                    val_target.append((review,domain,label))
                else:
                    test_data.append((review,domain,label))
        return labeled_target,val_target,test_data
    train_data = []
    test_data = []
    val_data = []
    source_unlabeled_data = []
    target_unlabeled_data = []
    src, tar = 1, 0  # domain ids
    root_path=os.path.join(os.getcwd(),"data/han_amazon/")
    print("source domain: ", source_domain, "target domain:", target_domain)
    # load training data
    for (mode, label) in [("train", "positive"), ("train", "negative")]:
        fname = root_path+"%s/tokens_%s.%s" % (source_domain, mode, label)
        train_data.extend(get_review(fname, src, label))
    print ("train-size: ", len(train_data))
    # load validation data
    for (mode, label) in [("test", "positive"), ("test", "negative")]:
        fname = root_path+"/%s/tokens_%s.%s" % (source_domain, mode, label)
        val_data.extend(get_review(fname, src, label))
    print ("val-size: ", len(val_data))
    # load testing data (the whole target domain, labeled splits included)
    for (mode, label) in [("train", "positive"), ("train", "negative"), ("test", "positive"), ("test", "negative")]:
        fname = root_path+"%s/tokens_%s.%s" % (target_domain, mode, label)
        test_data.extend(get_review(fname, tar, label))
    print("test-size: ", len(test_data))
    # load unlabeled data
    for (mode, label) in [("train", "unlabeled")]:
        fname = root_path+"%s/tokens_%s.%s" % (source_domain, mode, label)
        source_unlabeled_data.extend(get_review(fname, src, label))
        fname = root_path+"%s/tokens_%s.%s" % (target_domain, mode, label)
        target_unlabeled_data.extend(get_review(fname, tar, label))
    print("unlabeled-size: ", len(source_unlabeled_data), len(target_unlabeled_data))
    # Vocabulary spans both domains so the embedding covers everything.
    vocab = getVocab(train_data + val_data + test_data + source_unlabeled_data + target_unlabeled_data)
    print ("vocab-size: ", len(vocab))
    data = train_data + val_data + test_data + source_unlabeled_data + target_unlabeled_data
    source_data = train_data+val_data+source_unlabeled_data
    target_data = target_unlabeled_data
    # Corpus statistics: "stories" are reviews (lists of sentences).
    max_story_size = max(map(len, (pairs[0] for pairs in data)))
    mean_story_size = int(np.mean([len(pairs[0]) for pairs in data]))
    sentences = map(len, (sentence for pairs in data for sentence in pairs[0]))
    max_sentence_size = max(sentences)
    memory_size = min(args.memory_size, max_story_size)
    # The corpus-derived bound is overridden by the configured sentence size.
    max_sentence_size = args.sent_size
    file_path = os.path.join(os.getcwd(), './data/han_amazon/', "w2vec_" + args.source_domain + "2" + args.target_domain + ".pkl")
    # Reuse the cached embedding for this domain pair when available.
    if os.path.exists(file_path):
        word_embedding, word2idx, idx2word=load_domain_word(file_path)
    else:
        word_embedding, word2idx, idx2word = get_w2vec(vocab, args)
        save_domain_word(file_path, word_embedding, word2idx, idx2word)
    train_data = CommentDataSet(train_data, word2idx, idx2word, memory_size, max_sentence_size)
    train_loader = DataLoader(train_data, batch_size=args.bs, shuffle=True,
                              num_workers=args.workers, collate_fn=mycollate_fn, )
    val_data = CommentDataSet(val_data, word2idx, idx2word, memory_size, max_sentence_size)
    validation_loader = DataLoader(val_data, batch_size=args.bs, shuffle=True,
                                   num_workers=args.workers, collate_fn=mycollate_fn, )
    # Few-shot labeled target / target val / target test come from the pool.
    labeled_target, val_target ,test_data= split_target(args.num, test_data)
    target_test = CommentDataSet(test_data, word2idx, idx2word, memory_size, max_sentence_size)
    test_loader = DataLoader(target_test, batch_size=args.bs, shuffle=False,
                             num_workers=args.workers, collate_fn=mycollate_fn, )
    # Same data as target_test, but yielding sample indices and shuffled.
    target_dataset_unl= CommentDataSet(test_data, word2idx, idx2word, memory_size, max_sentence_size)
    target_dataset_unl.get_index=1
    target_loader_unl=DataLoader(target_dataset_unl, batch_size=args.bs, shuffle=True,
                                 num_workers=args.workers, collate_fn=mycollate_fn, )
    target_dataset_labeled = CommentDataSet(labeled_target, word2idx, idx2word, memory_size, max_sentence_size)
    target_loader = DataLoader(target_dataset_labeled, batch_size=args.bs, shuffle=True,
                               num_workers=args.workers, collate_fn=mycollate_fn, )
    val_target_dataset = CommentDataSet(val_target, word2idx, idx2word, memory_size, max_sentence_size)
    target_loader_val = DataLoader(val_target_dataset, batch_size=args.bs, shuffle=True,
                                   num_workers=args.workers, collate_fn=mycollate_fn, )
    return train_loader, validation_loader, target_loader, target_loader_unl, \
           target_loader_val, test_loader, word_embedding
def apply_train_dropout(m):
    """Re-enable dropout in an otherwise eval-mode network (MC dropout).

    Apply with ``module.apply(apply_train_dropout)`` after
    ``module.eval()``.
    """
    # isinstance (rather than the previous exact `type(m) ==` check) also
    # catches torch.nn.Dropout subclasses.
    if isinstance(m, torch.nn.Dropout):
        m.train()
def return_dataloader_by_UPS(args,netF,netC,unlabeled_data_loader):
    """Mine confident pseudo-labels from the unlabeled split via MC dropout.

    Same algorithm as the earlier copy of this function: score every
    unlabeled image with dropout active on two augmented views, average the
    logits, and per predicted class keep the ``args.max_num_per_class``
    lowest-entropy samples as pseudo-labeled target data.

    Returns:
        (target_loader, target_loader_unl, pseudo-label accuracy).

    NOTE(review): ``unlabeled_data_loader`` is unused.
    NOTE(review): ``project_root`` is commented out in this section's
    imports — this copy would raise NameError if executed; confirm which
    copy is live.
    NOTE(review): if no sample is selected, ``acc*1.0/total`` divides by
    zero.
    """
    netF.eval()
    netC.eval()
    base_path = project_root+"/data/SSDA_split/%s" % args.dataset
    # `in` does a substring test here — presumably always called with the
    # exact dataset name.
    if args.dataset in "office-home":
        root = project_root+"/data/OfficeHomeDataset/"
    else:
        root = project_root+"/data/%s/" % args.dataset
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    # In the UDA setting the "unlabeled" pool is the labeled source list;
    # its ground-truth labels are only used for the accuracy statistic.
    if args.uda==1:
        image_set_file_unl = \
            os.path.join(base_path,
                         'labeled_source_images_' +
                         args.t+".txt")
    else:
        image_set_file_unl = \
            os.path.join(base_path,
                         'unlabeled_target_images_' +
                         args.t + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # 'val' keeps random augmentation: TransformTwice below needs two
        # *different* stochastic views of each image.
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    netF.apply(apply_train_dropout)# MC dropout
    netC.apply(apply_train_dropout)
    bs = args.batch_size
    unlabel_target_list = open(image_set_file_unl).readlines()
    target_list = open(image_set_file_t).readlines()
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=TransformTwice(data_transforms['val']))
    target_dataset_unl.return_index=True
    # shuffle=False so batch order is deterministic while scoring.
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    start_test = True
    with torch.no_grad():
        all_index=[]
        for step, ((inputs1,inputs2), labels,index) in enumerate(target_loader_unl):
            inputs1,inputs2 = inputs1.cuda(),inputs2.cuda()
            output=[]
            batch_s=inputs2.shape[0]
            # 2 views x `repeat` stochastic forward passes per image.
            repeat=5
            for i in range(repeat):
                outputs1 = netC(netF(inputs1)).cpu()
                outputs2 = netC(netF(inputs2)).cpu()
                output.append(outputs1)
                output.append(outputs2)
            # Average the 2*repeat MC samples into one logit vector per image.
            output=torch.cat(output,dim=0).view(2*repeat,batch_s,-1)
            outputs=torch.mean(output,dim=0)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_index= index
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
                all_index= torch.cat((all_index, index), 0)
    _, pred_label = torch.max(all_output, 1)
    entropy = Entropy(torch.nn.Softmax(dim=1)(all_output))
    # entropy for each class
    class_to_mean_entropy = list(0. for i in range(args.class_num))
    total_add = list(0. for i in range(args.class_num))
    class2index=collections.defaultdict(list)
    classforentropy=collections.defaultdict(list)
    index2all_index=collections.defaultdict(dict)
    max_num_per_class=args.max_num_per_class
    for label in range(args.class_num):
        class_to_mean_entropy[label] = torch.mean(entropy[pred_label == label]).item()
        # Map each position within this class's subset back to the global
        # dataset index.
        index2all_index[label]=dict(list(zip(range(len(all_index[pred_label == label])),all_index[pred_label == label].numpy())))
        classentropy=entropy[pred_label == label]
        if len(classentropy)<max_num_per_class:
            indexes=range(len(classentropy))
        else:
            # Keep the max_num_per_class most confident (lowest-entropy) samples.
            preds, indexes = torch.topk(classentropy, max_num_per_class, largest=False)
            indexes=indexes.numpy()
        for ind in indexes:
            class2index[label].append(index2all_index[label][ind]) # local index to global index
    total=0
    acc=0
    line2remove=[]
    # Under UDA there are no labeled target samples: start from scratch.
    if args.uda == 1:
        target_list=[]
    for psudo_label, indexes in class2index.items():
        for index in indexes:
            line= unlabel_target_list[index]
            # Replace the ground-truth label field with the pseudo-label.
            psudo_line=line.split(" ")[0] + " " + str(psudo_label)
            target_list.append(psudo_line)
            line2remove.append(line)
            # Ground-truth label (2nd field) used only for statistics.
            total_add[int(line.split(" ")[1])] += 1
            if int(line.split(" ")[1])==psudo_label:
                acc+=1
            total+=1
    before_remove=len(unlabel_target_list)
    for line in line2remove :
        unlabel_target_list.remove(line)
    target_dataset = Imagelists_VISDA(target_list, root=root,
                                      transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(unlabel_target_list, root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl.return_index=True
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    # Restore full train mode before returning to the caller's training loop.
    netF.train()
    netC.train()
    return target_loader, target_loader_unl,acc*1.0/total
def return_dataset111111111111111111111(args):
    """Build all loaders for one SSDA source->target task.

    Reads the split .txt files under data/SSDA_split/<dataset>, splits the
    labeled source list into train/val subsets, and wraps every file list
    in an Imagelists_VISDA dataset.

    Returns:
        (source_loader, source_val_loader, target_loader, target_loader_unl,
         target_loader_val, target_loader_test, class_list)
    """
    base_path = project_root + "/data/SSDA_split/%s" % args.dataset
    # Fix: the original used `args.dataset in "office-home"`, a substring
    # test that would also match e.g. "office"; exact comparison is intended.
    if args.dataset == "office-home":
        root = project_root + "/data/OfficeHomeDataset/"
    else:
        root = project_root + "/data/%s/" % args.dataset
    image_set_file_s = os.path.join(
        base_path, 'labeled_source_images_' + args.s + '.txt')
    image_set_file_t = os.path.join(
        base_path, 'labeled_target_images_' + args.t + '_%d.txt' % (args.num))
    # Validation always uses the 3-shot split file, independent of args.num.
    image_set_file_t_val = os.path.join(
        base_path, 'validation_target_images_' + args.t + '_3.txt')
    image_set_file_unl = os.path.join(
        base_path, 'unlabeled_target_images_' + args.t + '_%d.txt' % (args.num))
    # AlexNet takes 227x227 crops; the other backbones take 224x224.
    crop_size = 227 if args.net == 'alexnet' else 224
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]),
        # NOTE(review): 'val' keeps the random flip/crop of the original code
        # (identical to 'train') -- confirm this augmentation is intended for
        # validation-style loading.
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    def read_lines(path):
        # Fix: use a context manager so file handles are not leaked
        # (the original called open(path).readlines() without closing).
        with open(path) as f:
            return f.readlines()

    def split(train_r, source_path):
        # Randomly split the listed source samples into train/val subsets.
        data = read_lines(source_path)
        train_len = int(len(data) * train_r)
        return torch.utils.data.random_split(
            data, [train_len, len(data) - train_len])

    # Fix: exact comparison instead of the substring test `in "multi"`.
    train_ratio = 0.95 if args.dataset == "multi" else 0.90
    source_train, source_val = split(train_r=train_ratio,
                                     source_path=image_set_file_s)
    print("source_train and val num", len(source_train), len(source_val))
    source_dataset = Imagelists_VISDA(source_train, root=root,
                                      transform=data_transforms['train'])
    source_val_dataset = Imagelists_VISDA(source_val, root=root,
                                          transform=data_transforms['val'])
    target_dataset = Imagelists_VISDA(read_lines(image_set_file_t), root=root,
                                      transform=data_transforms['val'])
    target_dataset_val = Imagelists_VISDA(read_lines(image_set_file_t_val),
                                          root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(read_lines(image_set_file_unl),
                                          root=root,
                                          transform=data_transforms['val'])
    # The unlabeled loader must also report per-sample indices.
    target_dataset_unl.return_index = True
    target_dataset_test = Imagelists_VISDA(read_lines(image_set_file_unl),
                                           root=root,
                                           transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    bs = args.batch_size
    source_loader = torch.utils.data.DataLoader(
        source_dataset, batch_size=bs, num_workers=args.worker,
        shuffle=True, drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(
        source_val_dataset, batch_size=bs, num_workers=args.worker,
        shuffle=False, drop_last=False)
    # Labeled target sets can be smaller than one batch, hence the min().
    target_loader = torch.utils.data.DataLoader(
        target_dataset, batch_size=min(bs, len(target_dataset)),
        num_workers=args.worker, shuffle=True, drop_last=True)
    target_loader_val = torch.utils.data.DataLoader(
        target_dataset_val, batch_size=min(bs, len(target_dataset_val)),
        num_workers=args.worker, shuffle=True, drop_last=False)
    target_loader_unl = torch.utils.data.DataLoader(
        target_dataset_unl, batch_size=bs, num_workers=args.worker,
        shuffle=True, drop_last=True)
    target_loader_test = torch.utils.data.DataLoader(
        target_dataset_test, batch_size=bs, num_workers=args.worker,
        shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
        target_loader_val, target_loader_test, class_list
def return_dataset_test(args):
    """Build the unlabeled-target test loader plus the source class list."""
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    image_set_file_test = os.path.join(
        base_path,
        'unlabeled_target_images_' + args.target + '_%d.txt' % (args.num))
    # AlexNet expects 227x227 inputs; other backbones expect 224x224.
    crop_size = 227 if args.net == 'alexnet' else 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    bs = 32 if args.net == 'alexnet' else 24
    target_loader_unl = torch.utils.data.DataLoader(
        target_dataset_unl, batch_size=bs * 2, num_workers=3,
        shuffle=False, drop_last=False)
    return target_loader_unl, class_list
def Entropy(input_):
    """Per-sample Shannon entropy of a batch of probability vectors.

    Args:
        input_: (B, C) tensor; assumed to hold probabilities -- TODO confirm.

    Returns:
        (B,) tensor with -sum(p * log(p + 1e-5)) over the class dimension.
    """
    epsilon = 1e-5  # guards log(0)
    return torch.sum(-input_ * torch.log(input_ + epsilon), dim=1)
\ No newline at end of file
from scipy.io import loadmat
import numpy as np
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_svhn():
    """Load the SVHN train/test .mat files.

    Returns (train_X, train_y, test_X, test_y): images as float32 NCHW,
    labels passed through dense_to_one_hot (which, despite its name, returns
    a flat label vector with 10 mapped to 0). Train is truncated to 25000
    samples and test to 9000.
    """
    train_mat = loadmat(base_dir + '/svhn_train_32x32.mat')
    test_mat = loadmat(base_dir + '/svhn_test_32x32.mat')

    def to_nchw(images):
        # .mat stores HWCN; convert to NCHW float32.
        return images.transpose(3, 2, 0, 1).astype(np.float32)

    train_images = to_nchw(train_mat['X'])[:25000]
    train_labels = dense_to_one_hot(train_mat['y'])[:25000]
    test_images = to_nchw(test_mat['X'])[:9000]
    test_labels = dense_to_one_hot(test_mat['y'])[:9000]
    return train_images, train_labels, test_images, test_labels
import numpy as np
from scipy.io import loadmat
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_syn(scale=True, usps=False, all_use=False):
    """Load the synthetic digits dataset (syn_number.mat).

    The scale/usps/all_use flags are accepted for interface compatibility
    but are unused, exactly as in the original implementation.

    Returns (train_X, train_y, test_X, test_y): float32 NCHW images; the
    training set is shuffled, labels kept in lockstep.
    """
    mat = loadmat(base_dir + '/syn_number.mat')
    train_images = mat['train_data'].transpose(0, 3, 1, 2).astype(np.float32)
    test_images = mat['test_data'].transpose(0, 3, 1, 2).astype(np.float32)
    order = np.random.permutation(train_images.shape[0])
    train_images = train_images[order]
    train_labels = dense_to_one_hot(mat['train_label'][order])
    test_labels = dense_to_one_hot(mat['test_label'])
    return train_images, train_labels, test_images, test_labels
import numpy as np
import pickle as pkl
def load_syntraffic():
    """Load the synthetic traffic-sign pickle and carve a 2000-sample test tail.

    Returns (train_X, train_y, test_X, test_y) with images as float32 NCHW.
    """
    # Fix: pickled data must be read in binary mode ('rb'); the original
    # opened the file in text mode, which fails under Python 3. A context
    # manager also guarantees the handle is closed.
    with open('../data/data_synthetic', 'rb') as f:
        data_source = pkl.load(f)
    n = len(data_source['image'])
    source_train = np.random.permutation(n)
    # NOTE(review): [:n] keeps every (shuffled) sample, so "train" and the
    # last-2000 "test" tail overlap; presumably a disjoint split was
    # intended -- confirm against callers before changing.
    data_s_im = data_source['image'][source_train[:n], :, :, :]
    data_s_im_test = data_source['image'][source_train[n - 2000:], :, :, :]
    data_s_label = data_source['label'][source_train[:n]]
    data_s_label_test = data_source['label'][source_train[n - 2000:]]
    data_s_im = data_s_im.transpose(0, 3, 1, 2).astype(np.float32)
    data_s_im_test = data_s_im_test.transpose(0, 3, 1, 2).astype(np.float32)
    return data_s_im, data_s_label, data_s_im_test, data_s_label_test
\ No newline at end of file
import numpy as np
import os
import os.path
from PIL import Image
# def pil_loader(path):
# with open(path, 'rb') as f:
# img = Image.open(f)
# return img.convert('RGB')
# def make_dataset_fromlist(image_list):
# # with open(image_list) as f:
# image_index = [x.split(' ')[0] for x in image_list]
# # with open(image_list) as f:
# label_list = []
# selected_list = []
# for ind, x in enumerate(image_list):
# label = x.split(' ')[1].strip()
# label_list.append(int(label))
# selected_list.append(ind)
# image_index = np.array(image_index)
# label_list = np.array(label_list)
# image_index = image_index[selected_list]
# return image_index, label_list
# def return_classlist(image_list):
# with open(image_list) as f:
# label_list = []
# for ind, x in enumerate(f.readlines()):
# label = x.split(' ')[0].split('/')[-2]
# if label not in label_list:
# label_list.append(str(label))
# return label_list
class Text(object):
    """List-backed text dataset yielding (text, label) pairs.

    The constructor takes a (texts, labels) pair. Depending on the flags,
    __getitem__ additionally returns the sample index (``return_index``)
    or the raw text again (``test`` mode).
    """

    def __init__(self, text, root="./data/multi/",
                 transform=None, target_transform=None, test=False):
        self.texts = text[0]
        self.labels = text[1]
        self.transform = transform  # kept for interface parity; not applied
        self.target_transform = target_transform
        self.root = root
        self.test = test
        self.return_index = False

    def __getitem__(self, index):
        """Return (text, target) for ``index``.

        Appends the index when return_index is set, or the raw text when in
        test mode (test mode takes precedence over return_index).
        """
        sample = self.texts[index]
        label = self.labels[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        if self.test:
            return sample, label, self.texts[index]
        if self.return_index:
            return sample, label, index
        return sample, label

    def __len__(self):
        """Number of stored texts."""
        return len(self.texts)
import torch.utils.data
import torchnet as tnt
from builtins import object
import torchvision.transforms as transforms
from data_pro.datasets_ import Dataset
class PairedData(object):
    """Zip five data loaders (four sources + one target) into joint batches.

    Shorter loaders are restarted when exhausted so every batch contains
    data from all five streams. Iteration stops once every loader has
    wrapped around at least once, or after max_dataset_size batches.
    """

    # Internal stream keys; they mirror the attribute suffixes below.
    _KEYS = ('A', 'B', 'C', 'D', 't')

    def __init__(self, data_loader_A, data_loader_B, data_loader_C,
                 data_loader_D, data_loader_t, max_dataset_size):
        self.data_loader_A = data_loader_A
        self.data_loader_B = data_loader_B
        self.data_loader_C = data_loader_C
        self.data_loader_D = data_loader_D
        self.data_loader_t = data_loader_t
        for key in self._KEYS:
            setattr(self, 'stop_' + key, False)
        self.max_dataset_size = max_dataset_size

    def __iter__(self):
        # Reset exhaustion flags, rewind every stream, restart the counter.
        for key in self._KEYS:
            setattr(self, 'stop_' + key, False)
            setattr(self, 'data_loader_%s_iter' % key,
                    iter(getattr(self, 'data_loader_' + key)))
        self.iter = 0
        return self

    def _pull(self, key):
        # Fetch the next batch from stream `key`; when the loader runs dry,
        # flag it as exhausted and restart it so the batch is still complete.
        iter_name = 'data_loader_%s_iter' % key
        try:
            return next(getattr(self, iter_name))
        except StopIteration:
            setattr(self, 'stop_' + key, True)
            fresh = iter(getattr(self, 'data_loader_' + key))
            setattr(self, iter_name, fresh)
            return next(fresh)

    def __next__(self):
        # Pull in the fixed A, B, C, D, t order (same as the original).
        batches = {}
        for key in self._KEYS:
            batches[key] = self._pull(key)
        all_wrapped = all(getattr(self, 'stop_' + key) for key in self._KEYS)
        if all_wrapped or self.iter > self.max_dataset_size:
            for key in self._KEYS:
                setattr(self, 'stop_' + key, False)
            raise StopIteration()
        self.iter += 1
        A, A_paths = batches['A']
        B, B_paths = batches['B']
        C, C_paths = batches['C']
        D, D_paths = batches['D']
        t, t_paths = batches['t']
        return {'S1': A, 'S1_label': A_paths,
                'S2': B, 'S2_label': B_paths,
                'S3': C, 'S3_label': C_paths,
                'S4': D, 'S4_label': D_paths,
                'T': t, 'T_label': t_paths}
class UnalignedDataLoader():
    """Wrap four source domains plus one target domain into a PairedData stream."""

    def initialize(self, source, target, batch_size1, batch_size2, scale=32):
        """Build per-domain loaders and the paired iterator.

        Args:
            source: sequence of four dicts, each with 'imgs' and 'labels'.
            target: dict with 'imgs' and 'labels'.
            batch_size1: batch size for every source loader.
            batch_size2: batch size for the target loader.
            scale: resize edge for the shared transform (default 32).
        """
        # Fix: transforms.Scale was a deprecated alias of Resize and has been
        # removed from torchvision; Resize is the drop-in replacement.
        transform = transforms.Compose([
            transforms.Resize(scale),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The four copy-pasted source stanzas collapsed into one loop; the
        # dataset_s1..dataset_s4 attributes are preserved for __len__.
        source_loaders = []
        for i in range(4):
            dataset = Dataset(source[i]['imgs'], source[i]['labels'],
                              transform=transform)
            setattr(self, 'dataset_s%d' % (i + 1), dataset)
            source_loaders.append(torch.utils.data.DataLoader(
                dataset, batch_size=batch_size1, shuffle=True, num_workers=4))
        dataset_target = Dataset(target['imgs'], target['labels'],
                                 transform=transform)
        data_loader_t = torch.utils.data.DataLoader(
            dataset_target, batch_size=batch_size2, shuffle=True,
            num_workers=4)
        self.dataset_t = dataset_target
        self.paired_data = PairedData(source_loaders[0], source_loaders[1],
                                      source_loaders[2], source_loaders[3],
                                      data_loader_t, float("inf"))

    def name(self):
        return 'UnalignedDataLoader'

    def load_data(self):
        return self.paired_data

    def __len__(self):
        # min(..., inf) is a no-op kept from the original interface.
        return min(max(len(self.dataset_s1), len(self.dataset_s2),
                       len(self.dataset_s3), len(self.dataset_s4),
                       len(self.dataset_t)), float("inf"))
"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
"""
import gzip
import os
import pickle
import urllib
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import datasets, transforms
class USPS(data.Dataset):
    """USPS Dataset.
    Args:
        root (string): Root directory of dataset where dataset file exist.
        train (bool, optional): If True, resample from dataset randomly.
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in
            an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
    """
    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset: optionally download, then load and preprocess."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num ot Test 1860
        self.transform = transform
        # Set by load_samples() below from the loaded label count.
        self.dataset_size = None
        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")
        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            # dataset_size equals the full label count at this point, so the
            # slice below keeps every sample (in original order).
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        # NOTE(review): scaling by 255 then casting to uint8 presumably means
        # the pickled pixel values are in [0, 1] -- confirm against the data.
        self.train_data *= 255.0
        self.train_data = np.squeeze(self.train_data).astype(np.uint8)
    def __getitem__(self, index):
        """Get images and target for data loader.
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, label = self.train_data[index], self.train_labels[index]
        # Rebuild a grayscale PIL image so torchvision transforms apply.
        img = Image.fromarray(img, mode='L')
        img = img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label.astype("int64")
    def __len__(self):
        """Return size of dataset."""
        return len(self.train_data)
    def _check_exists(self):
        """Check if dataset is download and in right place."""
        return os.path.exists(os.path.join(self.root, self.filename))
    def download(self):
        """Download dataset."""
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if os.path.isfile(filename):
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        # NOTE(review): this module only does `import urllib`; accessing
        # urllib.request may fail under Python 3 unless urllib.request was
        # imported elsewhere -- verify before relying on download=True.
        urllib.request.urlretrieve(self.url, filename)
        print("[DONE]")
        return
    def load_samples(self):
        """Load sample images from dataset."""
        filename = os.path.join(self.root, self.filename)
        # NOTE(review): the .pkl file is opened through gzip, so the file on
        # disk is expected to be gzip-compressed despite its extension.
        f = gzip.open(filename, "rb")
        data_set = pickle.load(f, encoding="bytes")
        f.close()
        if self.train:
            images = data_set[0][0]
            labels = data_set[0][1]
            self.dataset_size = labels.shape[0]
        else:
            images = data_set[1][0]
            labels = data_set[1][1]
            self.dataset_size = labels.shape[0]
        return images, labels
\ No newline at end of file
import numpy as np
from scipy.io import loadmat
import gzip
import pickle
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_usps(all_use=False):
    """Load USPS digits from usps_28x28.mat.

    Returns (train_X, train_y, test_X, test_y). Training images are
    shuffled, rescaled by 255, reshaped to (N, 1, 28, 28) and replicated
    four times to enlarge the training set; labels are kept in lockstep.
    ``all_use`` is accepted for interface compatibility but unused.
    """
    data_set = loadmat(base_dir + '/usps_28x28.mat')['dataset']
    train_images = data_set[0][0]
    train_labels = data_set[0][1]
    test_images = data_set[1][0]
    test_labels = data_set[1][1]
    order = np.random.permutation(train_images.shape[0])
    train_images = train_images[order] * 255
    train_labels = train_labels[order]
    test_images = test_images * 255
    train_images = train_images.reshape((train_images.shape[0], 1, 28, 28))
    test_images = test_images.reshape((test_images.shape[0], 1, 28, 28))
    train_labels = dense_to_one_hot(train_labels)
    test_labels = dense_to_one_hot(test_labels)
    # Quadruple the training set, replicating images and labels together.
    train_images = np.concatenate([train_images] * 4, 0)
    train_labels = np.concatenate([train_labels] * 4, 0)
    return train_images, train_labels, test_images, test_labels
import numpy as np
def weights_init(m):
    """Initialize conv and batch-norm layers in place (apply via net.apply).

    Conv weights and biases ~ N(0, 0.01); BatchNorm weights ~ N(1, 0.01)
    with zero bias. Modules of any other type are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.01)
        # Fix: conv layers built with bias=False have m.bias == None;
        # the original crashed on them.
        if m.bias is not None:
            m.bias.data.normal_(0.0, 0.01)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)
        m.bias.data.fill_(0)
def dense_to_one_hot(labels_dense):
    """Normalize digit labels into a flat float vector.

    Despite the name, this does NOT one-hot encode: it returns a 1-D float
    array of the labels with the value 10 remapped to 0 (SVHN-style '0'
    digit encoding). The name is kept because callers import it.
    """
    normalized = np.zeros((len(labels_dense),))
    for idx, label in enumerate(list(labels_dense)):
        normalized[idx] = 0 if label == 10 else label
    return normalized
SNPC,SP,CS,RS
5,64.6,60.6,60.7
10,66.6,63.2,62.8
15,67.1,64.3,62.2
20,66.6,64.2,63.5
25,66.8,64.8,62.6
30,67.1,64.7,63.2
"""
draw singular values of DomainNet with resnet34
"""
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
# sns.set_theme(style="darkgrid")
sns.set_theme(style="whitegrid", palette="pastel")
def draw(file_paths, dir, name):
    """Plot SNPC-vs-accuracy line charts for the RS/CS/SP transfer tasks.

    Args:
        file_paths: CSV files with columns SNPC, SP, CS, RS.
        dir: output directory for the PDF.
        name: output PDF filename.
    """
    scoremarkers = ["v", "s", "*", "o", "x", "+"]
    for i, path in enumerate(file_paths):
        fmri = pd.read_csv(path, sep=',')
        # The three duplicated lineplot stanzas collapsed into a loop;
        # column order (RS, CS, SP) and markers match the original.
        for a, column in enumerate(("RS", "CS", "SP")):
            sns.lineplot(x="SNPC", y=column, err_style="band", ci="sd",
                         marker=scoremarkers[a], linewidth=3, data=fmri)
    plt.xlabel("SNPC", fontsize=20)
    plt.ylabel('Accuracy,%', fontsize=20)
    plt.yticks(np.arange(55, 72, 5))
    # NOTE(review): legend order (S->P, C->S, R->S) does not match the plot
    # order (RS, CS, SP); kept as-is to preserve the original figure.
    plt.legend([r"S $\rightarrow$ P", r"C $\rightarrow$ S",
                r"R $\rightarrow$ S"], loc="lower left", fontsize=12)
    # Fix: pass dir and name as separate os.path.join arguments; the
    # original concatenated them (os.path.join(dir + name)), silently
    # relying on dir ending with a slash.
    plt.savefig(os.path.join(dir, name), format="pdf", bbox_inches="tight",
                dpi=400)
# plt.clf()
# Render the SNPC ablation figure from the aggregated CSV
# (machine-specific absolute paths).
draw(["/data/maning/git/shot/draw/SNPC.csv"
      ],
      "/data/maning/git/shot/draw/", "SNPCAcc.pdf")
# plt.clf()
# R-C (domainnet)
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_clipart.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/", "real2clipart_alphas1.pdf")
#R-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/","real2sketch_alphas1.pdf")
# C-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/","clipart2sketch_alphas1.pdf")
# S-P
# draw([
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_painting.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/","sketch2painting_alphas1.pdf")
\ No newline at end of file
alpha,RS,CS,SP
Adaptive,68.2,66.7,69.6
0.5',66.5,66.4,67.5
1',64.7,65,66
2',58.7,62.9,61.7
3',57.2,58.6,60.9
10',56.8,58.2,60.9
"""
draw singular values of DomainNet with resnet34
"""
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
sns.set_theme(style="darkgrid")
def draw(file_paths, dir, name):
    """Plot the alpha-ablation accuracies as a bar chart.

    Args:
        file_paths: CSV files with header columns alpha, RS, CS, SP.
        dir: output directory for the PDF.
        name: output PDF filename.
    """
    for path in file_paths:
        # Fix: the companion CSV (alpha_accs.csv) has a real header row
        # (alpha,RS,CS,SP); reading it with header=None, names=["score"]
        # produced a frame with no "alpha"/"RS" columns, so the barplot
        # below could not find its axes.
        fmri = pd.read_csv(path, sep=',')
        sns.barplot(x="alpha", y="RS", data=fmri)
    plt.xlabel("index", fontsize=20)
    plt.ylabel('singular values', fontsize=20)
    # Fix: join directory and filename as separate arguments instead of
    # concatenating inside os.path.join.
    plt.savefig(os.path.join(dir, name), format="pdf", bbox_inches="tight",
                dpi=400)
# plt.clf()
# Render the alpha-ablation bar chart (machine-specific absolute paths).
draw(["/data/maning/git/shot/draw/alpha_accs.csv"
      ],
      "/data/maning/git/shot/draw/", "alphaAcc.pdf")
#R-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/","real2sketch_alphas1.pdf")
# C-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/","clipart2sketch_alphas1.pdf")
# S-P
# draw([
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_painting.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/","sketch2painting_alphas1.pdf")
\ No newline at end of file
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
sns.set_theme(style="darkgrid")
def draw(file_paths, dir, draw="acc"):
    """Plot per-epoch test-accuracy or E1-score curves from training CSVs.

    Args:
        file_paths: CSVs with an 'epoch' column plus 'test_acc' / 'score'.
        dir: output directory for the PDF.
        draw: which panel to render. NOTE(review): the branch check below is
            a substring test, so the default "acc" selects the accuracy
            branch ("acc" in "test_acc"); kept as-is to preserve behavior.
    """
    scoremarkers = ["v", "o", "x", "s"]
    accmarkers = ["v", "o", "x", "s"]
    if draw in "test_acc":
        for i, path in enumerate(file_paths):
            fmri = pd.read_csv(path)
            fmri["test_acc"] = fmri["test_acc"] * 100  # fractions -> percent
            sns.lineplot(x="epoch", y="test_acc", err_style="band", ci="sd",
                         marker=accmarkers[i], linewidth=3, data=fmri)
        plt.xlabel("epoch", fontsize=20)
        plt.ylabel('Accuracies, (%)', fontsize=20)
        plt.yticks(np.arange(35, 65, 5))
        plt.legend(["EM", r'+BNM', "+Div", "+LP"], loc="lower left",
                   fontsize=15)
        # Fix: join directory and filename as separate os.path.join
        # arguments instead of string concatenation.
        plt.savefig(os.path.join(dir, "show_ACC.pdf"), format="pdf",
                    bbox_inches="tight", dpi=400)
    else:
        for i, path in enumerate(file_paths):
            fmri = pd.read_csv(path)
            sns.lineplot(x="epoch", y="score", err_style="band", ci="sd",
                         marker=scoremarkers[i], linewidth=3, data=fmri)
        plt.ylabel('E1 scores', fontsize=20)
        plt.yticks(np.arange(0, 0.6, 0.1))
        plt.legend(["EM", r'+BNM', "+Div", "+LP"], loc="upper left",
                   fontsize=15)
        # Same os.path.join fix as the accuracy branch.
        plt.savefig(os.path.join(dir, "show_SCORE.pdf"), format="pdf",
                    bbox_inches="tight", dpi=400)
# plt.show()
def draw_single(file_path):
    """Plot Acc and the S_mu statistics from a single CSV log.

    Reads `test_acc`, `zero`, `score` and `two` columns (fractions),
    rescales them to percentages and overlays the four curves, saving
    the figure next to the CSV as `<file_path>show.pdf`.
    """
    curves = pd.read_csv(file_path)
    for column in ("test_acc", "zero", "score", "two"):
        curves[column] = curves[column] * 100  # fraction -> percent
    # (column, marker, explicit color) per curve; only Acc gets a fixed
    # silver color, the rest use the current palette.
    series = [("test_acc", "v", "silver"),
              ("zero", "o", None),
              ("score", "s", None),
              ("two", "*", None)]
    for column, mark, line_color in series:
        ax = sns.lineplot(data=curves, x="epoch", y=column,
                          err_style = "band", ci="sd",
                          marker=mark, linewidth=3, color=line_color)
    plt.xlabel("epochs", fontsize=20)
    plt.ylabel('percentages', fontsize=20)
    plt.yticks(np.arange(0, 100, 10))
    plt.legend(["Acc", r'$S_{\mu=0}$', r'$S_{\mu=1}$', r'$S_{\mu>1}$'], loc="upper right", fontsize=12)
    plt.savefig(os.path.join(file_path + "show.pdf"), format="pdf", bbox_inches="tight", dpi=400)
# draw(["/data/maning/git/shot/ssda/2021_07_10office-home/seed2021/office-home/Real_World_norm0_temp0.05_lr0.001/tb_Real_World2Clipart_lr0.001_unl_ent0.3_unl_w0.0_vat_w0_div_w0.0bnm0.0_num1.csv",
# "/data/maning/git/shot/ssda/2021_07_07office-home/seed2021/office-home/Real_World_norm0_temp0.05_lr0.001/tb_Real_World2Clipart_lr0.001_unl_ent0.1_unl_w0.0_vat_w0_div_w0.0bnm1.0_num1.csv",
# "/data/maning/git/shot/ssda/2021_07_10office-home/seed2021/office-home/Real_World_norm0_temp0.05_lr0.001/tb_Real_World2Clipart_lr0.001_unl_ent0.3_unl_w0.0_vat_w0_div_w1.0bnm0.0_num1.csv",
# "/data/maning/git/shot/ssda/2021_07_10office-home/seed2021/office-home/Real_World_norm0_temp0.05_lr0.001/tb_Real_World2Clipart_lr0.001_unl_ent0.1_unl_w0.1_vat_w0_div_w0.0bnm0.0_num1.csv"
# ],"/data/maning/git/shot/ssda/2021_07_10office-home/seed2021/office-home/Real_World_norm0_temp0.05_lr0.001/", "test_acc")
#em
# draw_single("/data/maning/git/shot/ssda/2021_02_05office-home/seed2021/Office-31/webcam_norm1_temp0.05_lr0.001/tb_webcam2amazon_lr0.001_unl_ent0.2_unl_w0.0_vat_w0_div_w0bnm0_num1.csv")
#bnm
# draw_single("/data/maning/git/shot/ssda/2021_02_05office-home/seed2021/Office-31/webcam_norm1_temp0.05_lr0.001/tb_webcam2amazon_lr0.001_unl_ent0.3_unl_w0.0_vat_w0_div_w0bnm1.0_num1.csv")
#div
# draw_single("/data/maning/git/shot/ssda/2021_02_05office-home/seed2021/Office-31/webcam_norm1_temp0.05_lr0.001/tb_webcam2amazon_lr0.001_unl_ent1.0_unl_w0.0_vat_w0_div_w1.0bnm0_num1.csv")
#lp
# draw_single("/data/maning/git/shot/ssda/2021_02_05office-home/seed2021/Office-31/webcam_norm1_temp0.05_lr0.001/tb_webcam2amazon_lr0.001_unl_ent0.3_unl_w0.3_vat_w0_div_w0bnm0_num1.csv")
\ No newline at end of file
"""
draw singular values of DomainNet with resnet34
"""
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
sns.set_theme(style="darkgrid")
def draw(file_paths,dir,name):
    """Plot per-index singular-value spectra from one-column CSV files.

    Args:
        file_paths: CSVs with one singular value per line (no header).
        dir: output directory prefix (expected to end with '/').
        name: PDF file name appended to `dir`.
    """
    scoremarkers=["v","s","*","o","x","+"]
    # accmarkers=["v","s","*","o","x"]
    for i, path in enumerate(file_paths):
        fmri=pd.read_csv(path,sep=',',header=None,names=["score"],index_col=False)
        # One row per line of the file, so the row count gives the 1-based
        # x axis directly.  (The original re-read the file with a bare
        # `open(path)` that was never closed, leaking a file handle.)
        num = len(fmri)
        # fmri["score"]=fmri["score"]*100
        ax=sns.lineplot(x=range(1,num+1),y="score",err_style = "band",ci="sd",marker=scoremarkers[i],linewidth=3,
                        data=fmri)
    plt.xlabel("index",fontsize=20)
    plt.ylabel('singular values',fontsize=20)
    # plt.yticks(np.arange(0, 60, 10))
    plt.legend([r"$\alpha=0.5$",r"$\alpha=1$", r"$\alpha=2$",r"$\alpha=3$",r"$\alpha=10$",r'$Adaptive$'],loc="upper right",fontsize=12)
    plt.savefig(os.path.join(dir+ name), format="pdf",bbox_inches="tight",dpi = 400)
# plt.clf()
# R-C (domainnet)
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_clipart.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/", "real2clipart_alphas1.pdf")
#R-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/","real2sketch_alphas1.pdf")
# C-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/","clipart2sketch_alphas1.pdf")
# S-P
# draw([
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_painting.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/","sketch2painting_alphas1.pdf")
\ No newline at end of file
"""
draw singular values of DomainNet with resnet34
"""
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
# sns.set_theme(style="darkgrid")
sns.set_theme(style="whitegrid", palette="pastel")
def draw(file_paths,dir,name):
    """Plot train/val/test accuracy curves (in %) from training-log CSVs.

    Args:
        file_paths: CSVs with `epoch`, `train_acc`, `val_acc`, `test_acc`
            columns holding fractions in [0, 1].
        dir: output directory prefix (expected to end with '/').
        name: PDF file name appended to `dir`.
    """
    scoremarkers=["v","s","*","o","x","+"]
    # accmarkers=["v","s","*","o","x"]
    for i, path in enumerate(file_paths):
        fmri=pd.read_csv(path,sep=',',) #header=None,names=["score"],index_col=False
        # (Removed the original `num=sum(1 for line in open(path))`: the
        # value was never used here and the un-closed open() leaked a
        # file handle per CSV.)
        # Convert accuracy fractions to percentages for the y axis.
        fmri["train_acc"]=fmri["train_acc"]*100
        fmri["val_acc"]=fmri["val_acc"]*100
        fmri["test_acc"]=fmri["test_acc"]*100
        # One curve per split, each with its own marker.
        for a, column in enumerate(("train_acc", "val_acc", "test_acc")):
            ax=sns.lineplot(x="epoch",y=column,err_style = "band",ci="sd",marker=scoremarkers[a],linewidth=3,
                            data=fmri)
    plt.xlabel("epoch",fontsize=20)
    plt.ylabel('Accuracy,%',fontsize=20)
    plt.yticks(np.arange(55, 100, 5))
    plt.legend([r"Training",r"Validation", r"Test"],loc="center right",fontsize=12)
    plt.savefig(os.path.join(dir+ name), format="pdf",bbox_inches="tight",dpi = 400)
# plt.clf()
# Render the train/val/test accuracy curves for the recorded run
# (hard-coded input CSV and output PDF paths).
draw(["/data/maning/git/shot/draw/training-process.csv"
      ],
     "/data/maning/git/shot/draw/", "training-process.pdf")
# plt.clf()
# R-C (domainnet)
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_clipart.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/", "real2clipart_alphas1.pdf")
#R-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/","real2sketch_alphas1.pdf")
# C-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/","clipart2sketch_alphas1.pdf")
# S-P
# draw([
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_painting.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/","sketch2painting_alphas1.pdf")
\ No newline at end of file
epoch,train_acc,val_acc,test_acc,time
0,0.9861111111111112,0.5740740740740741,0.6128710962601208,2022_10_11_12_51_51
5,1.0,0.6322751322751323,0.6668808636422053,2022_10_11_13_36_05
10,1.0,0.6349206349206349,0.6857409073383883,2022_10_11_14_19_56
15,0.9997209821428571,0.6481481481481481,0.6931628325408046,2022_10_11_15_03_27
20,1.0,0.6507936507936508,0.6984963372317183,2022_10_11_15_46_16
25,0.9992559523809523,0.6455026455026455,0.7011309600308444,2022_10_11_16_28_21
30,0.9981863839285714,0.6481481481481481,0.7021269759670994,2022_10_11_17_08_53
35,0.9988567073170732,0.6455026455026455,0.702030587328107,2022_10_11_17_49_05
40,0.9988392857142857,0.6507936507936508,0.7028338259863771,2022_10_11_18_28_37
45,0.9991155660377359,0.6507936507936508,0.7019984577817762,2022_10_11_19_07_57
from utils import *
import argparse
import os
import os.path as osp
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from data_pro.mixmatch_return_dataset import return_dataset,return_dataloader_by_UPS,return_dataloader_by_progressive_UPS
from models.SSDA_basenet import *
from models.SSDA_resnet import *
from scipy.spatial.distance import cdist
from copy import deepcopy
import contextlib
# from loss_utils import *
import scipy
import scipy.stats
from itertools import cycle
def train_source(args):
    """Train the feature extractor (netF) and classifier (netC) on the
    labeled source domain with label-smoothed cross entropy, checkpointing
    the model with the best source-validation accuracy.

    Side effects: logs progress to args.out_file and saves
    "source_F_val.pt" / "source_C_val.pt" under args.output_dir.

    Returns:
        0 when a checkpoint already exists (training skipped); otherwise
        the (netF, netC) from the final epoch.  NOTE(review): these two
        return shapes are inconsistent — callers must handle both.
    """
    # Skip training entirely if the classifier checkpoint is already present.
    if os.path.exists(osp.join(args.output_dir, "source_C_val.pt")):
        print("train_file exist,",args.output_dir)
        return 0
    source_loader,source_val_loader, _, _, target_loader_val, \
    target_loader_test, class_list = return_dataset(args)
    netF,netC,_=get_model(args)
    netF=netF.cuda()
    netC=netC.cuda()
    # Per-parameter learning rates: backbone at the base lr, bottleneck and
    # classifier at 10x (typical fine-tuning recipe).
    param_group = []
    learning_rate = args.lr
    for k, v in netF.features.named_parameters():
        v.requires_grad = True
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netF.bottle_neck.named_parameters():
        v.requires_grad = True
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    for k, v in netC.named_parameters():
        v.requires_grad = True
        param_group += [{'params': v, 'lr': learning_rate * 10}]
    optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
    # optimizer = optim.Adam(param_group, weight_decay=5e-4)
    # Halve the lr when the mean epoch loss plateaus for `patience` steps.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.5, patience=50,
                                                           verbose=True, min_lr=1e-6)
    # scaler = torch.cuda.amp.GradScaler()
    acc_init = 0
    for epoch in (range(args.max_epoch_source)):
        netF.train()
        netC.train()
        # loss=nn.CrossEntropyLoss().cuda()
        # recon_losses is declared but never filled in this loop.
        total_losses,recon_losses,classifier_losses=[],[],[]
        iter_source = iter(source_loader)
        for _, (inputs_source, labels_source) in tqdm(enumerate(iter_source), leave=False):
            # Skip size-1 batches: they break BatchNorm statistics.
            if inputs_source.size(0) == 1:
                continue
            # inputs_source, labels_source = inputs_source, labels_source
            inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
            # print(inputs_source.shape)
            # print("input:",inputs_source.shape,labels_source)
            # with torch.cuda.amp.autocast():
            embeddings=netF(inputs_source)
            outputs_source = netC(embeddings)
            # logits,margin_logits = netC(embeddings,labels_source)
            # Label-smoothed cross entropy (epsilon=0.1) at temperature T=1.
            classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=0.1)(outputs_source,
                                                                                              labels_source,T=1)
            # classifier_loss = loss(outputs_source,labels_source)
            total_loss=classifier_loss
            # print("loss",total_loss)
            total_losses.append(total_loss.item())
            classifier_losses.append(classifier_loss.item())
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # scaler.scale(total_loss).backward()
            # scaler.step(optimizer)
            # scaler.update()
            # break
        # print("total_loss: {:.6f}, classify loss: {:.6f},recon_loss:{:.6f}".format())
        netF.eval()
        netC.eval()
        scheduler.step(np.mean(total_losses))
        # Evaluate on the source train and held-out source validation splits.
        acc_s_tr, _ = cal_acc(source_loader, netF, netC)
        acc_s_te, _ = cal_acc(source_val_loader, netF, netC)
        log_str = 'train_source , Task: {}, Iter:{}/{}; Accuracy = {:.2f}%/ {:.2f}%, total_loss: {:.6f}, classify loss: {:.6f},'\
            .format(args.s+"2"+args.t, epoch + 1, args.max_epoch_source, acc_s_tr * 100, acc_s_te * 100,np.mean(total_losses),np.mean(classifier_losses))
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')
        # Checkpoint whenever the validation accuracy ties or improves.
        if acc_s_te >= acc_init:
            acc_init = acc_s_te
            best_netF = netF.state_dict()
            best_netC = netC.state_dict()
            torch.save(best_netF, osp.join(args.output_dir, "source_F_val.pt"))
            torch.save(best_netC, osp.join(args.output_dir, "source_C_val.pt"))
    #
    # torch.save(best_netF, osp.join(args.output_dir, "source_F_val.pt"))
    # torch.save(best_netC, osp.join(args.output_dir, "source_C_val.pt"))
    return netF, netC
def test_target(args):
    """Evaluate the saved best source model on the target test set.

    Loads "source_F_val.pt"/"source_C_val.pt" from args.output_dir, wraps
    both nets in single-GPU DataParallel, and logs the target test accuracy
    to args.out_file.  NOTE(review): mutates args.modelpath as a side
    effect (it is left pointing at the classifier checkpoint).
    """
    _, _,_, _, _, target_loader_test, class_list = return_dataset(args)
    netF, netC, _ = get_model(args)
    args.modelpath = args.output_dir + '/source_F_val.pt'
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/source_C_val.pt'
    netC.load_state_dict(torch.load(args.modelpath))
    # Weights must be loaded before the DataParallel wrap: the wrapper
    # prefixes parameter names with "module.".
    netF=nn.DataParallel(netF,device_ids=[0])
    netC=nn.DataParallel(netC,device_ids=[0])
    netC=netC.cuda()
    netF=netF.cuda()
    netF.eval()
    netC.eval()
    acc, _,= cal_acc(target_loader_test, netF, netC)
    log_str = 'test_target Task: {}, Accuracy = {:.2f}%'.format(args.s+"2"+args.t, acc * 100)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
    # file_name1=osp.join(args.output_dir, 'tar_' + args.s+"2"+args.t+"_lr"+str(args.lr)
    #                     + "MNPC"+str(args.max_num_per_class)+ "_im"+str(args.im)
    #                     + "_u"+str(args.lambda_u)+"_unlent"+str(args.unlent)+"_num"+str(args.num)+'embeding_pretrain.tsv')
    # file_name2 = osp.join(args.output_dir, 'tar_' + args.s + "2" + args.t + "_lr" + str(args.lr)
    #                       + "MNPC" + str(args.max_num_per_class) + "_im" + str(args.im)
    #                       + "_u" + str(args.lambda_u) + "_unlent" + str(args.unlent) + "_num" + str(
    #     args.num) + 'meta_pretrain.tsv')
    # np.savetxt(file_name1, emb,delimiter='\t')
    # np.savetxt(file_name2, label, delimiter='\t')
def test_target_svd(args):
    """Compare source/target feature subspaces of the adapted target model.

    Loads the saved target checkpoints, computes a per-batch SVD of the
    transposed feature matrix for both domains, averages the singular
    values (S) and left singular bases (U) over full-size batches, prints
    the subspace (projection) distance between the two bases, and dumps
    the averaged spectra to CSV files under args.output_dir.
    """
    source_loader, _,_, _, _, target_loader_test, class_list = return_dataset(args)
    netF, netC, _ = get_model(args)
    name='tar_' + args.s+"2"+args.t+"_lr"+str(args.lr)+ "MNPC"+str(args.max_num_per_class)+ "_im"+str(args.im)+ "_u"+str(args.lambda_u)+"_unlent"+str(args.unlent)+"_nonlinear"+str(args.nonlinear)+"_wnl"+str(args.wnl)+"_alpha"+str(args.alpha)+"_num"+str(args.num)+"_"
    args.modelpath = args.output_dir + '/{}target_F.pt'.format(name)
    netF.load_state_dict(torch.load(args.modelpath))
    args.modelpath = args.output_dir + '/{}target_C.pt'.format(name)
    netC.load_state_dict(torch.load(args.modelpath))
    netC=netC.cuda()
    netF=netF.cuda()
    netF.eval()
    netC.eval()

    def _batch_svd_stats(loader):
        # Average S and U of per-batch SVDs over full-size batches.
        all_fea, all_output, all_label = compute(loader, netF, netC, args)
        bs = args.batch_size
        s_list, u_list = [], []
        for i in range(0, len(all_fea) - bs, bs):
            fea = all_fea[i:i + bs]
            u, s, v = torch.svd(fea.t())
            # bugfix: the original used `s.detach.cpu()` (missing call
            # parentheses), which raises AttributeError at runtime.
            s_list.append(s.detach().cpu())
            u_list.append(u.detach().cpu())
        # bugfix: the original `torch.cat(...).mean(dim=0)` collapsed S to
        # a scalar and U to a vector (breaking the later torch.mm);
        # stacking keeps per-index statistics.
        return torch.stack(s_list).mean(dim=0), torch.stack(u_list).mean(dim=0)

    target_S, target_U = _batch_svd_stats(target_loader_test)
    # bugfix: the original rebuilt the "source" statistics from the
    # target_s/target_u lists (copy-paste error).
    source_S, source_U = _batch_svd_stats(source_loader)
    # Projection distance between the two averaged left-singular bases.
    p_s, cospa, p_t = torch.svd(torch.mm(source_U.t(), target_U))
    sinpa = torch.sqrt(1-torch.pow(cospa,2))
    subspace_distance=torch.norm(sinpa,1)
    print("subspace_distance: ", subspace_distance)
    # bugfix: `reshpae` typo, and both savetxt calls wrote source_S to the
    # same file; keep the original file name for the source spectrum and
    # write the target spectrum to a distinct "<name>target.csv".
    np.savetxt(args.output_dir+"/"+name+".csv", source_S.numpy().reshape(-1), delimiter=",")
    np.savetxt(args.output_dir+"/"+name+"target.csv", target_S.numpy().reshape(-1), delimiter=",")
def train_target(args):
source_loader, source_val_loader, target_loader, target_loader_unl, \
target_loader_val, target_loader_test, (labeled_target, unlabeled_target)=return_dataset(args)
# args.lr=0.01
len_target_loader=len(target_loader)
len_target_loader_unl=len(target_loader_unl)
netF, netC, netD = get_model(args)
# print(get_para_num(netF))
# # print(get_para_num(netF.bottle_neck))
# print(get_para_num(netC))
# print(get_para_num(netD))
args.modelpath = args.output_dir + '/source_F_val.pt'
netF.load_state_dict(torch.load(args.modelpath))
args.modelpath = args.output_dir + '/source_C_val.pt'
netC.load_state_dict(torch.load(args.modelpath))
netF = netF.cuda()
netC = netC.cuda()
param_group = []
# for k, v in netF.named_parameters():
# v.requires_grad=True
# param_group += [{'params': v, 'lr': args.lr}]
for k, v in netF.features.named_parameters():
v.requires_grad=True
param_group += [{'params': v, 'lr': args.lr}]
for k, v in netF.bottle_neck.named_parameters():
v.requires_grad=True
param_group += [{'params': v, 'lr': args.lr*10}]
if args.update_cls==1:
for k, v in netC.named_parameters():
v.requires_grad = True
param_group += [{'params': v, 'lr': args.lr}]
else:
for k, v in netC.named_parameters():
v.requires_grad = False
# param_group += [{'params': v, 'lr': args.lr}]
# for k, v in netD.named_parameters():
# v.requires_grad=True
# param_group += [{'params': v, 'lr': args.lr*10}]
optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
factor=0.8, patience=20,
verbose=True, min_lr=1e-5)
# vat_loss=VATLoss()
scaler = torch.cuda.amp.GradScaler()
netF=nn.DataParallel(netF,device_ids=[0])
netC=nn.DataParallel(netC,device_ids=[0])
max_pred_acc=-1
best_test_acc = -1
acc,acc_val=0,0
best_F,bestC,bestD=None,None,None
first_epoch_acc=-1
psudo_acc=-1
for epoch in (range(args.max_epoch_target)):
if not args.max_num_per_class==0:
args.cur_epoch=epoch+1
# target_loader, target_loader_unl, psudo_acc= return_dataloader_by_UPS(args,netF,netC,target_loader_unl)
target_loader, target_loader_unl, psudo_acc= return_dataloader_by_progressive_UPS(args,netF,netC,target_loader_unl)
# psudo_acc=0
# active_target_loader, active_target_loader_unl= target_loader, target_loader_unl
#iter_target = iter(active_target_loader)
#iter_target_unl = iter(active_target_loader_unl)
total_losses, classifier_losses,unl_losses,ws,entropy_losses,div_losses,ls= [], [],[],[],[],[],[]
netF.train()
netC.train()
#for step, ((inputs_x, targets_x), ((inputs_u,inputs_u2),_)) in tqdm(enumerate(zip(cycle(active_target_loader),active_target_loader_unl)), leave=False):
for step, ((inputs_x, targets_x), ((inputs_u,inputs_u2),_)) in enumerate(zip(cycle(target_loader),target_loader_unl)):
# for step, (unlabeled_target, _) in (enumerate(iter_target_unl)):
# if unlabeled_target.size(0) == 1:
# continue
start=time.time()
#if step % len_target_loader==0:
#print("dataloader")
#iter_target=iter(active_target_loader)
# inputs_test = inputs_test
#inputs_x,targets_x=next(iter_target)
batch_size = inputs_x.size(0)
# print(batch_size,inputs_x.shape,inputs_u.shape,inputs_u2.shape)
# Transform label to one-hot
targets_x = torch.zeros(batch_size, args.class_num).scatter_(1, targets_x.view(-1, 1).long(), 1)
#print(time.time()-start)
inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
inputs_u = inputs_u.cuda()
inputs_u2 = inputs_u2.cuda()
#print(time.time()-start)
entropy_a=None
#with torch.cuda.amp.autocast():
with torch.no_grad():
# compute guessed labels of unlabel samples
size=inputs_u.shape[0]
#outputs_u = netC(netF(inputs_u))
#outputs_u2 = netC(netF(inputs_u2))
input_cat=torch.cat([inputs_u, inputs_u2], dim=0)
out_cat=netC(netF(input_cat))
outputs_u=out_cat[0:size]
outputs_u2=out_cat[size:]
p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
entropy_a = Entropy(p)
pt = p ** (1 / args.T)
targets_u = pt / pt.sum(dim=1, keepdim=True)
targets_u = targets_u.detach()
# mixup
all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)
#for self mixup
# all_inputs = torch.cat([ inputs_u, inputs_u2], dim=0)
# all_targets = torch.cat([targets_u, targets_u], dim=0)
l = np.random.beta(args.alpha, args.alpha)
l = max(l, 1 - l)
idx = torch.randperm(all_inputs.size(0))
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
if args.nonlinear==1:
with torch.no_grad():
# softmax_out = nn.Softmax(dim=1)(netC(netF(input_a)))
# # entropy_a = torch.mean((1-F.sigmoid(Entropy(softmax_out))),dim=0, keepdim=True)*0.1
# entropy_a = Entropy(softmax_out)
# entropy_a = (F.sigmoid(Entropy(softmax_out)))*1.5
# # entropy_a[0:args.batch_size]=0
# entropy_a=entropy_a.unsqueeze(1)
# print("a",entropy_a,entropy_a.shape,target_a.shape)
# softmax_out = nn.Softmax(dim=1)(netC(netF(input_b)))
# # entropy_b =torch.mean((1-F.sigmoid(Entropy(softmax_out))),dim=0,keepdim=True)*0.1
# entropy_b =(F.sigmoid(Entropy(softmax_out)))
# entropy_b=entropy_b.unsqueeze(1)
# all_entropy=torch.cat([entropy_a,entropy_b], dim=1)
# print(torch.mean(all_entropy),all_entropy.shape)
# nn.Softmax(dim=1)(all_entropy)[:,0]
# l=
alpha=np.exp(torch.mean(entropy_a).item())
# alpha2=(torch.mean(entropy_a).item())
l = np.random.beta(alpha, 0.5)
# l = min(l, 1 - l)
# l=l.unsqueeze(1)
# print(torch.mean(l),l.shape)
# print("b",entropy_b,entropy_b.shape,input_b.shape)
# mixed_input = l * input_a + (1 - l) * input_b
# # mixed_input = l.reshape(-1,1,1,1) * input_a + (1-l.reshape(-1,1,1,1)) * input_b
# mixed_target = l*target_a + target_b*(1-l)
# else:
mixed_input = l * input_a + (1 - l) * input_b
mixed_target = l * target_a + target_b*(1-l)
# interleave labeled and unlabed samples between batches to get correct batchnorm calculation
mixed_input = list(torch.split(mixed_input, batch_size))
#mixed_input = interleave(mixed_input, batch_size)
ls.append(l)
#logits = [netC(netF(mixed_input[0]))]
#for input in mixed_input[1:]:
#logits.append(netC(netF(input)))
logits=netC(netF(torch.cat(mixed_input,dim=0)))
logits=list(torch.split(logits, batch_size))
# put interleaved samples backd
#logits = interleave(logits, batch_size)
logits_x = logits[0]
logits_u = torch.cat(logits[1:], dim=0)
targets_x = mixed_target[:batch_size]
targets_u = mixed_target[batch_size:]
probs_u = torch.softmax(logits_u, dim=1)
Lx = -torch.mean(torch.sum(F.log_softmax(logits_x, dim=1) * targets_x, dim=1))
#withted unlabeled loss
# wnl
if args.wnl==1:
with torch.no_grad():
prb_u = torch.softmax(logits_u, dim=1)
prb_u_entropy=((Entropy(prb_u))).reshape(-1,1)
# print("min",torch.min(targets_u_entropy),"max",torch.max(targets_u_entropy), "mean",torch.mean(targets_u_entropy))
prb_u_entropy=prb_u_entropy.repeat(1,targets_u.shape[1])
Lu = torch.mean((((probs_u - targets_u) )** 2) * ((1/prb_u_entropy**2)))
else :
Lu = torch.mean((probs_u - targets_u) ** 2)
# print(step / len_target_loader_unl)
# if args.wnl==1: w=args.lambda_u
# else : w = args.lambda_u * exp_rampup(epoch + step / len_target_loader_unl, args.max_epoch_target)
w = args.lambda_u * exp_rampup(epoch + step / len_target_loader_unl, args.max_epoch_target)
# w = args.lambda_u
ws.append(w)
classifier_losses.append(Lx.item())
unl_losses.append(Lu.item())
total_loss = Lx + w * Lu
# im_loss = 0
# softmax_out = nn.Softmax(dim=1)(netC(netF(inputs_u)))
# un_labeled_entropy = torch.mean(Entropy(softmax_out))
# im_loss+=args.unlent*un_labeled_entropy
# entropy_losses.append(un_labeled_entropy.item())
# entropy_losses.append(0)
# if args.nonlinear==1:
# args.alpha=un_labeled_entropy.item()
# msoftmax = softmax_out.mean(dim=0)
# tmp = torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
# div_losses.append(tmp.item())
# im_loss -= args.div_w * tmp
#
# # im_loss= im_loss
# # scaler.scale(im_loss).backward()
# # im_loss.backward()
# total_loss=total_loss+im_loss* args.im
total_losses.append(total_loss.item())
# total_loss
#scaler.scale(total_loss).backward()
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
#scaler.step(optimizer)
#scaler.update()
#optimizer.zero_grad()
netF.eval()
netC.eval()
if epoch %5==0:
acc, _ = cal_acc(target_loader_test, netF, netC)
train_acc, _ = cal_acc(target_loader, netF, netC)
acc_val, _ = cal_acc(target_loader_val, netF, netC)
training_process={"epoch": epoch,"train_acc":train_acc, "val_acc":acc_val, "test_acc":acc}
# training_process["train_acc"]=train_acc
# training_process["train_acc"]=train_acc
write_csv_file("training-process.csv",training_process)
log_str = 'tra_tgt: {}, I:{}/{}; test_acc = {:.2f}% ,acc_val = {:.2f}% ,total_l: {:.5f}, Lx: {:.5f}, Lu: {:.8f}, w: {:.5f},ent_l: {:.3f}, l: {:.3f} psudo_acc {:.2f} gpu {}'.format(
args.s + "2" + args.t, epoch + 1, args.max_epoch_target, acc * 100, acc_val * 100,
np.mean(total_losses), np.mean(classifier_losses), np.mean(unl_losses), np.mean(ws),
np.mean(entropy_losses), np.mean(ls),psudo_acc,args.gpu_id)
args.out_file.write(log_str + '\n')
args.out_file.flush()
print(log_str + '\n')
scheduler.step(np.mean(classifier_losses))
# print("unl_w", args.unl_w)
if max_pred_acc<acc_val:
best_F=deepcopy(netF)
bestC=deepcopy(netC)
max_pred_acc=acc_val
if best_test_acc<acc:
best_test_acc=acc
if epoch==0:
first_epoch_acc=acc
# if acc<first_epoch_acc-0.15:# stop the program if acc drop down.
# var = vars(args)
# var["val_acc"] = max_pred_acc
# var["test_acc"] = acc
# var["best_test_acc"] = best_test_acc
# # write_shared_file("run_officehome.txt",[args.out_file.name+': val_acc:{:.2f}, test_acc:{:.2f} \n'.format(max_pred_acc,acc)])
# write_csv_file("run_office-home_2022.05.17.csv", var)
# return
# max_cluter_acc=cluster_acc
name='tar_' + args.s+"2"+args.t+"_lr"+str(args.lr)+ "MNPC"+str(args.max_num_per_class)+ "_im"+str(args.im)+ "_u"+str(args.lambda_u)+"_unlent"+str(args.unlent)+"_nonlinear"+str(args.nonlinear)+"_wnl"+str(args.wnl)+"_alpha"+str(args.alpha)+"_num"+str(args.num)+"_"
torch.save(best_F, osp.join(args.output_dir, name+"target_F.pt"))
# torch.save netB.state_dict(), osp.join(args.output_dir, "target_B.pt"))
torch.save(bestC, osp.join(args.output_dir, name+"target_C.pt"))
acc, _ = cal_acc(target_loader_test, best_F, bestC)
log_str="test_acc: {:.4f} ".format(acc)
args.out_file.write(log_str + '\n')
args.out_file.flush()
print(log_str + '\n')
# file_name1 = osp.join(args.output_dir, 'tar_' + args.s + "2" + args.t + "_lr" + str(args.lr)
# + "MNPC" + str(args.max_num_per_class) + "_im" + str(args.im)
# + "_u" + str(args.lambda_u) + "_unlent" + str(args.unlent) + "_num" + str(
# args.num) + 'embeding_adaptation.tsv')
# file_name2 = osp.join(args.output_dir, 'tar_' + args.s + "2" + args.t + "_lr" + str(args.lr)
# + "MNPC" + str(args.max_num_per_class) + "_im" + str(args.im)
# + "_u" + str(args.lambda_u) + "_unlent" + str(args.unlent) + "_num" + str(
# args.num) + 'meta_adaptation.tsv')
# np.savetxt(file_name1, emb, delimiter='\t')
# np.savetxt(file_name2, labels, delimiter='\t')
var=vars(args)
var["val_acc"]=max_pred_acc
var["test_acc"]=acc
var["best_test_acc"] = best_test_acc
write_csv_file("run_office-home_2022.05.17.csv",var)
return netF, netC
def Entropy(input_):
    """Per-sample Shannon entropy of a batch of probability vectors.

    :param input_: (batch, num_classes) tensor of probabilities (rows
        assumed to already be softmax outputs).
    :return: (batch,) tensor of entropies; 1e-5 guards against log(0).
    """
    return torch.sum(-input_ * torch.log(input_ + 1e-5), dim=1)
def print_args(args):
    """Render an argparse Namespace as 'key:value' lines under a separator.

    :param args: any object exposing its settings via ``__dict__``.
    :return: multi-line string, one attribute per line, trailing newline.
    """
    header = "==========================================\n"
    body = "".join("{}:{}\n".format(key, value) for key, value in args.__dict__.items())
    return header + body
@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
def switch_attr(m):
if hasattr(m, 'track_running_stats'):
m.track_running_stats ^= True
model.apply(switch_attr)
yield
model.apply(switch_attr)
def _l2_normalize(d):
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
return d
class VATLoss(nn.Module):
    """Virtual Adversarial Training loss.

    Estimates a small input perturbation that maximally changes the model's
    prediction (approximate power iteration on the KL divergence), then
    returns the local distributional smoothness penalty
    KL(p(x) || p(x + r_adv)).
    """
    def __init__(self, xi=10.0, eps=1, ip=1):
        """VAT loss
        :param xi: hyperparameter of VAT (default: 10.0) — scale of the probe
            step used while estimating the adversarial direction
        :param eps: hyperparameter of VAT (default: 1.0) — final perturbation radius
        :param ip: iteration times of computing adv noise (default: 1)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, netF, netC, x):
        """Return the scalar LDS penalty for batch ``x`` under netC(netF(.))."""
        with torch.no_grad():
            # Reference distribution at the clean input; treated as a constant
            # target (no gradient flows into it).
            pred = F.softmax(netC(netF(x)), dim=1)
        # Random unit direction to seed the adversarial-direction search.
        d = torch.rand(x.shape).sub(0.5).to(x.device)
        d = _l2_normalize(d)
        with _disable_tracking_bn_stats(netF):
            # Approximate power iteration: gradient of KL w.r.t. the probe
            # direction points toward the most sensitive input direction.
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = netC(netF(x + self.xi * d))
                logp_hat = F.log_softmax(pred_hat, dim=1)
                adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                adv_distance.backward()
                d = _l2_normalize(d.grad)
                # Clear gradients accumulated by the probe backward pass.
                netF.zero_grad()
                netC.zero_grad()
            # LDS at the final adversarial perturbation of radius eps.
            r_adv = d * self.eps
            pred_hat = netC(netF((x + r_adv)))
            logp_hat = F.log_softmax(pred_hat, dim=1)
            lds = F.kl_div(logp_hat, pred, reduction='batchmean')
        return lds
def test_target_svd(args):
    """Compare the spectral structure of source vs. target features.

    Loads the saved (best) adapted model, extracts features for both domains
    (subsampled via ``compute_stride``), groups them into per-batch matrices
    and computes a batched SVD.  Logs the principal-angle subspace distance
    between the mean source/target left-singular bases and dumps the mean
    singular-value spectra to CSV files for plotting.

    Side effects: writes to ``args.out_file`` and two CSVs under
    ``args.output_dir``; requires CUDA.
    """
    from collections import OrderedDict
    source_loader, _, _, target_loader_unl, _, target_loader_test, class_list = return_dataset(args)
    netF, netC, _ = get_model(args)
    # Checkpoint name must match the one produced by the training routine.
    name = 'tar_' + args.s + "2" + args.t + "_lr" + str(args.lr) + "MNPC" + str(args.max_num_per_class) + "_im" + str(args.im) + "_u" + str(args.lambda_u) + "_unlent" + str(args.unlent) + "_nonlinear" + str(args.nonlinear) + "_wnl" + str(args.wnl) + "_alpha" + str(args.alpha) + "_num" + str(args.num) + "_"
    if args.nonlinear == 1:
        args.modelpath = args.output_dir + '/target_F.pt'
    else:
        args.modelpath = args.output_dir + '/{}target_F.pt'.format(name)
    # The checkpoint stores a whole module (torch.save(model)); strip a
    # possible DataParallel 'module.' prefix from its state-dict keys.
    F_dic = torch.load(args.modelpath)
    new_F_dic = OrderedDict()
    for k, v in F_dic.state_dict().items():
        mname = k.replace('module.', '')
        new_F_dic[mname] = v
    netF.load_state_dict(new_F_dic)
    if args.nonlinear == 1:
        args.modelpath = args.output_dir + '/target_C.pt'
    else:
        args.modelpath = args.output_dir + '/{}target_C.pt'.format(name)
    C_dic = torch.load(args.modelpath)
    new_C_dic = OrderedDict()
    for k, v in C_dic.state_dict().items():
        mname = k.replace('module.', '')
        new_C_dic[mname] = v
    netC.load_state_dict(new_C_dic)
    netC = netC.cuda()
    netF = netF.cuda()
    netF.eval()
    netC.eval()
    all_fea, all_output, all_label = compute_stride(target_loader_unl, netF, netC, args, istarget=True)
    len_data = len(all_fea)
    bs = args.batch_size
    # Group features into (num_batches, bs, dim); assumes the sampled count
    # is divisible by the batch size — TODO confirm for drop_last settings.
    all_fea = all_fea.reshape(-1, bs, all_fea.shape[1])
    print('begin', all_fea.shape)
    # SVD of (dim, bs) matrices: u spans the feature subspace of each batch.
    u, s, v = torch.linalg.svd(all_fea.cuda().permute(0, 2, 1))
    print("done", u.shape, s.shape, v.shape)
    target_S = torch.mean(s, dim=0)
    target_U = torch.mean(u, dim=0)
    all_fea, all_output, all_label = compute_stride(source_loader, netF, netC, args)
    len_data = len(all_fea)
    bs = args.batch_size
    all_fea = all_fea.reshape(-1, bs, all_fea.shape[1])
    print('begin', all_fea.shape)
    u, s, v = torch.linalg.svd(all_fea.cuda().permute(0, 2, 1))
    print("done", u.shape, s.shape, v.shape)
    source_S = torch.mean(s, dim=0)
    source_U = torch.mean(u, dim=0)
    # Principal angles between the mean subspaces: singular values of
    # U_s^T U_t are cos(theta); the L1 norm of sin(theta) is the distance.
    p_s, cospa, p_t = torch.svd(torch.mm(source_U.t(), target_U))
    sinpa = torch.sqrt(1 - torch.pow(cospa, 2))
    subspace_distance = torch.norm(sinpa, 1)
    source_S = source_S.reshape(-1).cpu().numpy()
    target_S = target_S.reshape(-1).cpu().numpy()
    log_str = "subspace_distance:{:.4f}".format(subspace_distance)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()
    print(log_str + '\n')
    # Dump the mean singular-value spectra for offline analysis/plotting.
    np.savetxt(args.output_dir + "/" + name + args.s + ".csv", source_S, delimiter=",")
    np.savetxt(args.output_dir + "/" + name + args.t + ".csv", target_S, delimiter=",")
def compute_stride(loader, netF, netC, args, istarget=False):
    """Collect features and logits for a subsample of ``loader``'s batches.

    Runs netC(netF(.)) under ``torch.no_grad()`` and concatenates results on
    the CPU.  The iterator is advanced only when ``i % 5 == 0``, so this
    consumes roughly the first ``len(loader) // 5`` batches — presumably the
    intent is just cheap subsampling for the downstream SVD (TODO confirm a
    strided sample was not intended).

    BUG FIX: ``iter_test.next()`` (Python 2 style) replaced with the
    Python 3 iterator protocol ``next(iter_test)``.

    :param loader: sized iterable of batches; when ``istarget`` each item is
        ``(data, extra)`` and inputs are ``data[0]``, otherwise the item
        itself is indexed with ``data[0]``.
    :param netF: feature extractor (called on CUDA inputs).
    :param netC: classifier head applied to the features.
    :param args: unused; kept for call-site uniformity.
    :param istarget: whether the loader yields ``(data, extra)`` pairs.
    :return: ``(all_fea, all_output, None)`` — labels are not collected.
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for i in range(len(loader)):
            if i % 5 == 0:
                if istarget:
                    # target loader yields (data, extra); keep only data
                    data, _ = next(iter_test)
                else:
                    data = next(iter_test)
                inputs = data[0]
                inputs = inputs.cuda()
                feas = netF(inputs)
                outputs = netC(feas)
                if start_test:
                    all_fea = feas.float().cpu()
                    all_output = outputs.float().cpu()
                    start_test = False
                else:
                    all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
                    all_output = torch.cat((all_output, outputs.float().cpu()), 0)
    return all_fea, all_output, None
def compute(loader, netF, netC, args):
    """Extract features, logits and labels for every batch of ``loader``.

    Runs netC(netF(.)) under ``torch.no_grad()``; inputs are moved to CUDA,
    results are accumulated on the CPU.

    BUG FIX: ``iter_test.next()`` (Python 2 style) replaced with the
    Python 3 iterator protocol ``next(iter_test)``.

    :param loader: sized iterable yielding ``(inputs, labels, ...)`` batches.
    :param netF: feature extractor.
    :param netC: classifier head applied to the features.
    :param args: unused; kept for call-site uniformity.
    :return: ``(all_fea, all_output, all_label)`` concatenated over batches,
        with labels cast to float.
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for _ in range(len(loader)):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            feas = netF(inputs)
            outputs = netC(feas)
            if start_test:
                all_fea = feas.float().cpu()
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    return all_fea, all_output, all_label
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label smoothing.

    The hard one-hot target is mixed with a uniform distribution:
    ``q = (1 - epsilon) * onehot + epsilon / num_classes``.

    Args:
        num_classes: number of classes (width of the logit vector).
        epsilon: smoothing mass moved from the true class to the uniform term.
        use_gpu: move the smoothed targets to CUDA before computing the loss.
        size_average: if True return a scalar (mean over batch, summed over
            classes); otherwise a per-sample loss vector.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, size_average=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.size_average = size_average
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets, T=1.0):
        """Smoothed cross entropy of logits ``inputs`` at temperature ``T``."""
        log_probs = self.logsoftmax(inputs / T)
        # Build the one-hot matrix on the CPU (scatter indices must match device).
        one_hot = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
        if self.use_gpu:
            one_hot = one_hot.cuda()
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        if self.size_average:
            return (- smoothed * log_probs).mean(0).sum()
        return (- smoothed * log_probs).sum(1)
def cal_acc(loader, netF, netC):
    """Classification accuracy and mean prediction entropy over ``loader``.

    Runs netC(netF(.)) under ``torch.no_grad()``, accumulating logits and
    labels on the CPU; requires CUDA for the forward pass.

    BUG FIX: ``iter_test.next()`` (Python 2 style) replaced with the
    Python 3 iterator protocol ``next(iter_test)``.

    :param loader: sized iterable yielding ``(inputs, labels)`` batches.
    :param netF: feature extractor.
    :param netC: classifier head.
    :return: ``(accuracy in [0, 1], mean softmax entropy)``.
    """
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for i in range(len(loader)):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            labels = labels.cuda()
            outputs = netC(netF(inputs))
            labels = labels.cpu()
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    mean_ent = torch.mean(Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
    return accuracy, mean_ent
def cal_acc11111111111111111(loader, netF, netC):
    """Debug variant of ``cal_acc`` that also returns sampled embeddings.

    Same accuracy/entropy computation as ``cal_acc`` but additionally keeps
    the features from netF and returns a random sample of 1000 of them (with
    matching labels) as numpy arrays, e.g. for embedding visualization.
    Requires at least 1000 evaluated samples and CUDA.

    BUG FIX: ``iter_test.next()`` (Python 2 style) replaced with the
    Python 3 iterator protocol ``next(iter_test)``.

    :return: ``(accuracy, mean_entropy, sampled_embeddings, sampled_labels)``.
    """
    import random
    import numpy as np
    start_test = True
    with torch.no_grad():
        iter_test = iter(loader)
        for i in range(len(loader)):
            data = next(iter_test)
            inputs = data[0]
            labels = data[1]
            inputs = inputs.cuda()
            labels = labels.cuda()
            embs = netF(inputs)
            outputs = netC(embs)
            labels = labels.cpu()
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                all_embs = embs.float().cpu()
                start_test = False
            else:
                all_embs = torch.cat((all_embs, embs.float().cpu()), 0)
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)
    _, predict = torch.max(all_output, 1)
    accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
    mean_ent = torch.mean(Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
    # Random subset of embeddings for visualization (e.g. tensorboard projector).
    index = random.sample(list(range(len(all_label))), 1000)
    return accuracy, mean_ent, all_embs[index, :].cpu().numpy(), all_label[index].cpu().numpy()
def get_model(args):
    """Build the backbone (netF) and classifier head (netC) for ``args.net``.

    :param args: namespace providing net, bottleneck, class_num,
        norm_feature and temp.
    :return: ``(netF, netC, netD)`` — netD is always None here.
    :raises ValueError: for an unsupported ``args.net``.
    """
    if args.net == 'resnet34':
        netF = resnet34(args=args)
    elif args.net == 'resnet50':
        netF = resnet50(args=args)
    elif args.net == "alexnet":
        netF = AlexNetBase(bootleneck_dim=args.bottleneck)
    elif args.net == "vgg":
        netF = VGGBase(bootleneck_dim=args.bottleneck)
    else:
        raise ValueError('Model cannot be recognized.')
    # Every backbone shares the same classifier head configuration.
    netC = Predictor(num_class=args.class_num, inc=args.bottleneck,
                     norm_feature=args.norm_feature, temp=args.temp)
    print(get_para_num(netF))
    print(get_para_num(netC))
    return netF, netC, None
if __name__ == "__main__":
    # CLI entry point: parses hyper-parameters, picks a GPU, seeds RNGs and
    # dispatches to source pre-training or target adaptation + evaluation.
    parser = argparse.ArgumentParser(description='Domain Adaptation')
    parser.add_argument('--gpu_id', type=int, default=1, help="device id to run")
    parser.add_argument('--net', type=str, default="vgg", choices=['vgg', "alexnet", "resnet34", 'LeNet', 'resnet50', 'ResNet101'])
    parser.add_argument('--s', type=str, default="Real_World", help="source office_home :Art Clipart Product Real_World")
    parser.add_argument('--t', type=str, default="Product", help="target Art Clipart Product Real_World")
    parser.add_argument('--max_epoch_source', type=int, default=30, help="maximum epoch ")
    parser.add_argument('--max_epoch_target', type=int, default=30, help="maximum epoch ")
    parser.add_argument('--num', type=int, default=1, help="labeled_data per class")
    parser.add_argument('--train', type=int, default=1, help="if to train")
    parser.add_argument('--batch_size', type=int, default=32, help="batch_size")
    parser.add_argument('--class_num', type=int, default=65, help="batch_size", choices=[65, 10, 31, 126])
    parser.add_argument('--worker', type=int, default=16, help="number of workers")
    parser.add_argument('--dataset', type=str, default='office-home', choices=['office-home', 'multi', 'digits', 'Office-31'])
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--seed', type=int, default=2021, help="random seed")
    parser.add_argument('--update_cls', type=int, default=0, help="random seed")
    parser.add_argument('--max_num_per_class', type=int, default=0, help="random seed")
    parser.add_argument('--norm_feature', type=int, default=0, help="random seed")
    parser.add_argument('--par', type=float, default=1)
    parser.add_argument('--temp', type=float, default=0.05)
    parser.add_argument('--alpha', type=float, default=0.75)
    parser.add_argument('--lambda_u', type=float, default=200)
    parser.add_argument('--im', type=float, default=1)
    parser.add_argument('--T', type=float, default=0.5)
    parser.add_argument('--bottleneck', type=int, default=256)
    parser.add_argument('--smooth', type=float, default=0.1)
    parser.add_argument('--output', type=str, default='2021_02_03Office-31')
    parser.add_argument('--epsilon', type=float, default=1e-5)
    parser.add_argument('--unlent', type=float, default=1)
    parser.add_argument('--unl_w', type=float, default=0)
    parser.add_argument('--vat_w', type=float, default=0)
    parser.add_argument('--div_w', type=float, default=1)
    # BUG FIX: these defaults were the *string* '0' (argparse does not run
    # defaults through `type`), so e.g. args.uda held '0' (truthy!) instead of
    # the int 0 when the flag was omitted.  Integer defaults keep explicit
    # comparisons (args.wnl == 1) and name strings (str(0) == str('0')) identical.
    parser.add_argument('--uda', type=int, default=0, choices=[0, 1])
    parser.add_argument('--nonlinear', type=int, default=0, choices=[0, 1])
    parser.add_argument('--wnl', type=int, default=0, choices=[0, 1])
    args = parser.parse_args()
    # Pick a GPU with enough free memory and pin the process to it.
    args.gpu_id = getAvaliableDevice(min_mem=24000)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    setup_seed(args.seed)
    import warnings
    warnings.filterwarnings("ignore")
    current_folder = "./ssda"
    args.output_dir = osp.join(current_folder, args.output, 'seed' + str(args.seed), args.dataset,
                               args.s + "_norm" + str(args.norm_feature) + "_temp" + str(args.temp) + "_lr" + str(args.lr))
    # BUG FIX: replaced the shell-dependent `os.system('mkdir -p ...')` (plus a
    # redundant os.mkdir fallback) with the portable stdlib equivalent.
    os.makedirs(args.output_dir, exist_ok=True)
    if args.train == 1:
        # Stage 1: supervised training on the source domain.
        args.out_file = open(osp.join(args.output_dir, 'log_src_val.txt'), 'w')
        args.out_file.write(print_args(args) + '\n')
        args.out_file.flush()
        train_source(args)
    else:
        # Stage 2: adapt to the target domain, then run the SVD analysis.
        args.out_file = open(osp.join(args.output_dir, 'tar_' + args.s + "2" + args.t + "_lr" + str(args.lr)
                                      + "MNPC" + str(args.max_num_per_class) + "_im" + str(args.im)
                                      + "_u" + str(args.lambda_u) + "_unlent" + str(args.unlent)
                                      + "_nonlinear" + str(args.nonlinear) + "_wnl" + str(args.wnl) + "_alpha" + str(args.alpha)
                                      + "_num" + str(args.num) + '.txt'), 'w')
        test_target(args)
        args.out_file.write(print_args(args) + '\n')
        args.out_file.flush()
        train_target(args)
        test_target_svd(args)
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# Maps a config string (e.g. "ResNet50") to the torchvision constructor;
# consumed by ResNetEncoder to instantiate a pretrained backbone.
resnet_dict = {"ResNet18": models.resnet18, "ResNet34": models.resnet34, "ResNet50": models.resnet50,
               "ResNet101": models.resnet101, "ResNet152": models.resnet152}
def get_para_num(net):
    """Count a module's parameters.

    :param net: any ``nn.Module``.
    :return: dict with 'Total' and 'Trainable' (requires_grad) element counts.
    """
    counts = [p.numel() for p in net.parameters()]
    trainable_counts = [p.numel() for p in net.parameters() if p.requires_grad]
    return {'Total': sum(counts), 'Trainable': sum(trainable_counts)}
def init_weights(m):
    """Initialize a module's parameters by layer type (for ``Module.apply``).

    Conv/ConvTranspose: Kaiming-uniform weights; Linear: Xavier-normal
    weights; BatchNorm: N(1.0, 0.02) weights.  Biases are zeroed.

    BUG FIX: guard against ``None`` parameters — modules built with
    ``bias=False`` (Conv/Linear) or ``affine=False`` (BatchNorm) used to
    raise AttributeError inside ``nn.init``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
class ResNetEncoder(nn.Module):
    """Torchvision ResNet trunk (conv1 .. avgpool) used as a feature extractor.

    forward() returns the flattened pooled features; the classification head
    and bottleneck are intentionally left out (see the commented-out code).
    """
    def __init__(self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
        # NOTE(review): use_bottleneck / bottleneck_dim / new_cls / class_num
        # are currently unused (the corresponding code below is commented
        # out); they are accepted only for interface compatibility.
        super(ResNetEncoder, self).__init__()
        # ImageNet-pretrained backbone selected by name from resnet_dict.
        model_resnet = resnet_dict[resnet_name](pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \
                                            self.layer1, self.layer2, self.layer3, self.layer4,
                                            self.avgpool
                                            )
        # Flattened width after avgpool; assumes a 2048-wide backbone
        # (ResNet50/101/152) — TODO confirm ResNet18/34 (512-wide) are unused.
        self.in_features = 2048 * 1 * 1
        # self.use_bottleneck = use_bottleneck
        # self.new_cls = new_cls
        # if new_cls:
        #     if self.use_bottleneck:
        #         self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim)
        #         # self.fc = nn.Linear(bottleneck_dim, class_num)
        #         self.bottleneck.apply(init_weights)
        #         # self.fc.apply(init_weights)
        #         self.__in_features = bottleneck_dim
        #     else:
        #         self.fc = nn.Linear(model_resnet.fc.in_features, class_num)
        #         self.fc.apply(init_weights)
        #         self.__in_features = model_resnet.fc.in_features
        # else:
        #     self.fc = model_resnet.fc
        #     self.__in_features = model_resnet.fc.in_features
    def forward(self, x):
        """Return flattened (batch, in_features) pooled features."""
        x = self.feature_layers(x)
        x = x.view(x.size(0), -1)
        return x
    def output_num(self):
        # NOTE(review): self.__in_features is never assigned (only
        # self.in_features is set in __init__), so calling this raises
        # AttributeError — appears to be dead code; confirm before relying on it.
        return self.__in_features
    def get_parameters(self):
        # NOTE(review): self.new_cls / self.use_bottleneck / self.bottleneck
        # are never stored (their assignments are commented out in __init__),
        # so this method raises AttributeError if called — likely dead code.
        if self.new_cls:
            if self.use_bottleneck:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
                                  ]
            else:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  ]
        else:
            parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
        return parameter_list
class ResNetDecoder(nn.Module):
    """Transposed-conv decoder from a bottleneck vector to a 3x224x224 image.

    A linear layer expands the bottleneck back to the (2048, 1, 1) latent
    map, then a transposed-conv stack upsamples the spatial size
    1 -> 3 -> 7 -> 14 -> 28 -> 56 -> 112 -> 224, ending in a sigmoid.
    """
    def __init__(self, bottleneck_dim=256, latent_dim=[2048, 1, 1]):
        # NOTE: mutable default latent_dim is shared between instances; it is
        # only read here, so this is harmless in practice.
        super(ResNetDecoder, self).__init__()
        self.lin2 = nn.Linear(bottleneck_dim, latent_dim[0] * latent_dim[1] * latent_dim[2])
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(latent_dim[0] * latent_dim[1] * latent_dim[2])
        self.drop = nn.Dropout(0.5)
        self.lin2.apply(init_weights)
        self.latent_dim = latent_dim
        self.bottleneck_dim = bottleneck_dim
        self.t_conv1 = nn.ConvTranspose2d(2048, 256, kernel_size=3, stride=1)   # 1 -> 3
        self.t_conv11 = nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2)   # 3 -> 7
        self.t_conv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)    # 7 -> 14
        self.t_conv3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)    # 14 -> 28
        self.t_conv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)     # 28 -> 56
        self.t_conv5 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)      # 56 -> 112
        self.t_conv6 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)       # 112 -> 224
        self.t_conv1.apply(init_weights)
        self.t_conv11.apply(init_weights)
        self.t_conv2.apply(init_weights)
        self.t_conv3.apply(init_weights)
        self.t_conv4.apply(init_weights)
        self.t_conv5.apply(init_weights)
        self.t_conv6.apply(init_weights)
        self.t_conv = nn.Sequential(
            self.t_conv1, nn.ReLU(), nn.BatchNorm2d(256), nn.Dropout(0.3),
            self.t_conv11, nn.ReLU(), nn.BatchNorm2d(256), nn.Dropout(0.3),
            self.t_conv2, nn.ReLU(), nn.BatchNorm2d(128), nn.Dropout(0.3),
            self.t_conv3, nn.ReLU(), nn.BatchNorm2d(128), nn.Dropout(0.3),
            self.t_conv4, nn.ReLU(), nn.BatchNorm2d(64), nn.Dropout(0.3),
            self.t_conv5, nn.ReLU(), nn.BatchNorm2d(64), nn.Dropout(0.3),
            self.t_conv6,
            nn.Sigmoid()
        )
    def forward(self, x):
        """Decode a (batch, bottleneck_dim) vector into a (batch, 3, 224, 224) image."""
        # Expand and regularize, then reshape to the latent feature map.
        x = self.drop(self.bn1(self.relu(self.lin2(x))))
        x = x.view(-1, self.latent_dim[0], self.latent_dim[1], self.latent_dim[2])
        x = self.t_conv(x)
        return x
class LeNetEncoder(nn.Module):
    """LeNet-style convolutional encoder producing a 256-d feature vector.

    For a 1x28x28 input the conv stack yields 50x4x4 = 800 activations, which
    a linear bottleneck maps to 256 features (BN + ReLU + dropout).
    """
    def __init__(self, use_bottleneck=True, bottleneck_dim=256):
        # NOTE(review): lin1 hard-codes 800 -> 256 while the BatchNorm uses
        # bottleneck_dim, so bottleneck_dim is effectively fixed at 256 —
        # confirm no caller passes a different value.  use_bottleneck is unused.
        super(LeNetEncoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = nn.Sequential(
            self.conv1,
            nn.Dropout2d(0.1),
            nn.ReLU(),
            self.pool,
            self.conv2,
            nn.Dropout2d(0.3),
            nn.ReLU(),
            self.pool
        )
        # Conv-output shape for a 28x28 input: (50, 4, 4).
        self.latent_dim = [50, 4, 4]
        self.bottleneck_dim = bottleneck_dim
        self.lin1 = nn.Linear(800, 256)
        self.lin1.apply(init_weights)
        self.bottle = nn.Sequential(
            self.lin1,
            nn.BatchNorm1d(self.bottleneck_dim, affine=True),
            nn.ReLU(),
            nn.Dropout(p=0.3)
        )
    def forward(self, x):
        """Return (batch, 256) bottleneck features for (batch, 1, 28, 28) inputs."""
        x = self.encoder(x)
        x = self.bottle(x.view(x.size(0), -1))
        return x
class LeNetDecoder(nn.Module):
    """Mirror of LeNetEncoder: expand a bottleneck vector to a 1x28x28 image.

    A linear layer restores the (50, 4, 4) latent map, then transposed convs
    upsample the spatial size 4 -> 8 -> 12 -> 24 -> 28, ending in a sigmoid.
    """
    def __init__(self, bottleneck_dim=256, latent_dim=[50, 4, 4]):
        super(LeNetDecoder, self).__init__()
        self.lin2 = nn.Linear(bottleneck_dim, latent_dim[0] * latent_dim[1] * latent_dim[2])
        self.lin2.apply(init_weights)
        self.latent_dim = latent_dim
        self.bottleneck_dim = bottleneck_dim
        self.t_conv1 = nn.ConvTranspose2d(50, 40, kernel_size=2, stride=2)
        self.t_conv2 = nn.ConvTranspose2d(40, 20, kernel_size=5)
        self.t_conv3 = nn.ConvTranspose2d(20, 10, kernel_size=2, stride=2)
        self.t_conv4 = nn.ConvTranspose2d(10, 1, kernel_size=5)
        for layer in (self.t_conv1, self.t_conv2, self.t_conv3, self.t_conv4):
            layer.apply(init_weights)

    def forward(self, x):
        """Decode a (batch, bottleneck_dim) vector into a (batch, 1, 28, 28) image."""
        feat = self.lin2(x)
        channels, height, width = self.latent_dim
        feat = feat.view(-1, channels, height, width)
        feat = torch.relu(self.t_conv1(feat))
        feat = torch.relu(self.t_conv2(feat))
        feat = torch.relu(self.t_conv3(feat))
        return torch.sigmoid(self.t_conv4(feat))
class ConvAutoencoder(nn.Module):
    """Plain (non-variational) convolutional autoencoder.

    Pairs a LeNet or ResNet encoder with the matching decoder.  forward()
    mirrors ConvVariationalAutoencoder's signature by returning
    (reconstruction, mu, logvar, z) with mu/logvar fixed to None.
    """
    def __init__(self, encoder_type="LeNet", bottleneck_dim=256, use_bootleneck=True):
        super(ConvAutoencoder, self).__init__()
        if "LeNet" in encoder_type:
            # NOTE(review): self.bottleneck is only created on the ResNet path,
            # yet forward() calls it unconditionally — the LeNet path would
            # raise AttributeError; confirm only "ResNet*" is used here.
            self.encoder = LeNetEncoder(use_bottleneck=use_bootleneck, bottleneck_dim=bottleneck_dim)
            self.decoder = LeNetDecoder(bottleneck_dim, latent_dim=self.encoder.latent_dim)
        elif "ResNet" in encoder_type:
            self.encoder = ResNetEncoder(resnet_name=encoder_type, use_bottleneck=use_bootleneck,
                                         bottleneck_dim=bottleneck_dim, new_cls=True, class_num=10)
            self.bottleneck = feat_bootleneck(self.encoder.in_features, bottleneck_dim=bottleneck_dim, type="bn")
            self.decoder = ResNetDecoder(bottleneck_dim=bottleneck_dim)
    def forward(self, x):
        """Return (reconstruction, None, None, z) — non-variational variant."""
        x = self.encoder(x)
        x = self.bottleneck(x)
        z = x
        x = self.decoder(x)
        return x, None, None, z  # none variational, both mu and var are None
    def loss_criterion(self, recon_x, x, mu=None, logvar=None):
        """Mean BCE between the reconstruction and sigmoid-squashed input.

        mu/logvar are accepted (and ignored) to match the VAE interface.
        """
        bce = F.binary_cross_entropy(recon_x, torch.sigmoid(x), reduction='mean')
        return bce
    def get_feature_weight(self, recon_x, x):
        """Per-sample weights inversely related to reconstruction error.

        The per-sample BCE is min-max normalized, inverted (more loss, less
        weight) and rescaled so the weights sum to the batch size.
        """
        assert len(recon_x) == len(x)
        bce = F.binary_cross_entropy(recon_x, torch.sigmoid(x), reduction='none')
        bce = torch.mean(bce.view(bce.shape[0], -1), dim=1, keepdim=True)
        # normalize
        min_v = torch.min(bce)
        range_v = torch.max(bce) - min_v
        normalised_bce = (bce - min_v) / range_v
        # more loss, less weight
        weight = 1 - normalised_bce
        reweight = x.shape[0] * (weight / torch.sum(weight))
        return reweight
    def get_finetune_modules(self):
        """Modules to fine-tune downstream (decoder excluded)."""
        return [self.encoder, self.bottleneck]
    def gen_embedding(self, x):
        """Deterministically encode x to its bottleneck representation."""
        x = self.encoder(x)
        x = self.bottleneck(x)
        return x
#
#
#
# class ConvDenoisingAutoencoder(nn.Module):
# def __init__(self):
# super(ConvDenoisingAutoencoder, self).__init__()
# self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
# self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
#
# self.conv3 = nn.Conv2d(32, 64, 5, stride=5)
# self.conv4 = nn.Conv2d(64, 128, 5, stride=5)
#
# self.pool = nn.MaxPool2d(2, 2)
#
# self.t_conv1 = nn.ConvTranspose2d(128, 64, 5, stride=5)
# self.t_conv2 = nn.ConvTranspose2d(64, 32, 5, stride=5)
#
# self.t_conv3 = nn.ConvTranspose2d(32, 16, 2, stride=2)
# self.t_conv4 = nn.ConvTranspose2d(16, 3, 2, stride=2)
#
# def forward(self, x):
# # add noise
# x_noisy = x + x.data.new(x.size()).normal_(0, 0.1).type_as(x)
# x = torch.relu(self.conv1(x))
# x = self.pool(x)
# x = torch.relu(self.conv2(x))
# x = self.pool(x)
#
# x = torch.relu(self.conv3(x))
# x = torch.relu(self.conv4(x))
#
# x = torch.relu(self.t_conv1(x))
# x = torch.relu(self.t_conv2(x))
# x = torch.relu(self.t_conv3(x))
# x = torch.sigmoid(self.t_conv4(x))
#
# return x
#
# def gen_embedding(self, x):
# x = torch.relu(self.conv1(x))
# x = self.pool(x)
# x = torch.relu(self.conv2(x))
# x = self.pool(x)
#
# x = torch.relu(self.conv3(x))
# x = torch.relu(self.conv4(x))
#
# x = x.view(x.size(0), -1)
#
# return x
class feat_bootleneck(nn.Module):
    """Linear projection to a bottleneck width with optional BatchNorm.

    Args:
        feature_dim: input feature width.
        bottleneck_dim: output width of the projection.
        type: "bn" applies BatchNorm after the projection; anything else
            returns the raw linear output.
    """
    def __init__(self, feature_dim, bottleneck_dim=256, type="bn"):
        super(feat_bootleneck, self).__init__()
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type

    def forward(self, x):
        """Project x to (batch, bottleneck_dim), normalized when type == 'bn'."""
        projected = self.bottleneck(x)
        return self.bn(projected) if self.type == "bn" else projected
import torch.nn.utils.weight_norm as weightNorm
class feat_classifier(nn.Module):
    """Single linear classification head, optionally weight-normalized.

    Args:
        class_num: number of output classes.
        bottleneck_dim: input feature width.
        type: "linear" for a plain nn.Linear, "wn" (default) to wrap it in
            weight normalization.
    """
    def __init__(self, class_num, bottleneck_dim=256, type="wn"):
        super(feat_classifier, self).__init__()
        if type == "linear":
            self.fc1 = nn.Linear(bottleneck_dim, class_num)
        elif type == 'wn':
            self.fc1 = weightNorm(nn.Linear(bottleneck_dim, class_num), name="weight")
        self.fc1.apply(init_weights)
        # Kept as a Sequential so extra layers can be slotted in easily.
        self.linears = nn.Sequential(
            self.fc1,
        )

    def forward(self, x):
        """Return (batch, class_num) logits."""
        return self.linears(x)
class ConvVariationalAutoencoder(nn.Module):
    """Convolutional VAE reusing the LeNet/ResNet encoder-decoder pairs.

    The encoder output is mapped to (mu, logvar) by two linear heads and the
    decoder reconstructs from a reparameterized sample z.
    """
    def __init__(self, encoder_type="LeNet", bottleneck_dim=256, use_bootleneck=True):
        super(ConvVariationalAutoencoder, self).__init__()
        if "LeNet" in encoder_type:
            self.encoder = LeNetEncoder(use_bottleneck=use_bootleneck, bottleneck_dim=bottleneck_dim)
            self.decoder = LeNetDecoder(bottleneck_dim, latent_dim=self.encoder.latent_dim)
            # NOTE(review): latent_dim=800 but LeNetEncoder.forward emits a
            # 256-d feature, so the view in forward() would fail — confirm
            # whether the LeNet path is actually exercised.
            self.latent_dim = 800
        elif "ResNet" in encoder_type:
            self.encoder = ResNetEncoder(resnet_name=encoder_type, use_bottleneck=use_bootleneck,
                                         bottleneck_dim=bottleneck_dim, new_cls=True, class_num=10)
            self.latent_dim = 2048
            # Decoder consumes the full latent width, not bottleneck_dim.
            self.decoder = ResNetDecoder(bottleneck_dim=self.latent_dim)
        # Linear heads producing the Gaussian posterior parameters.
        self.trans_mu = nn.Linear(self.latent_dim, self.latent_dim)
        self.trans_var = nn.Linear(self.latent_dim, self.latent_dim)
        print(self.encoder)
        print("encoder para:", get_para_num(self.encoder))
        print("decoder para:", get_para_num(self.decoder))
    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        """Return (reconstruction, mu, logvar, z)."""
        x = self.encoder(x)
        mu = self.trans_mu(x.view(-1, self.latent_dim))
        logvar = self.trans_var(x.view(-1, self.latent_dim))
        z = self.reparameterize(mu, logvar)
        x = self.decoder(z)
        return x, mu, logvar, z
    def loss_criterion(self, recon_x, x, mu, logvar):
        """Mean BCE reconstruction term (against sigmoid(x)) plus summed KL divergence."""
        bce = F.binary_cross_entropy(recon_x, torch.sigmoid(x), reduction='mean')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld
    def get_feature_weight(self, recon_x, x):
        """Per-sample weights inversely related to reconstruction error.

        The per-sample BCE is min-max normalized, inverted (more loss, less
        weight) and rescaled so the weights sum to the batch size.
        """
        assert len(recon_x) == len(x)
        bce = F.binary_cross_entropy(recon_x, torch.sigmoid(x), reduction='none')
        bce = torch.mean(bce.view(bce.shape[0], -1), dim=1, keepdim=True)
        # normalize
        min_v = torch.min(bce)
        range_v = torch.max(bce) - min_v
        normalised_bce = (bce - min_v) / range_v
        # more loss, less weight
        weight = 1 - normalised_bce
        reweight = x.shape[0] * (weight / torch.sum(weight))
        return reweight
    def get_finetune_modules(self):
        """Modules to fine-tune downstream: encoder plus the (mu, logvar) heads."""
        return [self.encoder, nn.Sequential(self.trans_mu, self.trans_var)]
    def gen_embedding(self, x):
        """Encode x and return a sampled latent z (stochastic)."""
        x = self.encoder(x)
        mu = self.trans_mu(x.view(-1, self.latent_dim))
        logvar = self.trans_var(x.view(-1, self.latent_dim))
        z = self.reparameterize(mu, logvar)
        return z
# # train_data = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
# # test_data = datasets.MNIST(root='data', train=False, download=True, transform=transforms.ToTensor())
#
# # batch_size = 512
# # train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=0)
# # test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=0)
#
# # model = ConvAutoencoder().cuda()
# # model = ConvDenoisingAutoencoder().cuda()
# model = ConvVariationalAutoencoder().cuda()
# criterion = nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
#
#
# def train(epochs):
# for epoch in range(1, epochs + 1):
# train_loss = 0
# for _ in range(10):
# images = torch.randn(512, 200, 100, 3)
# images = images.transpose(3, 2).transpose(2, 1)
# images = images.cuda()
# optimizer.zero_grad()
#
# # For Vanilla/Denoising autoencoder
# # outputs = model(images)
# # loss = criterion(outputs, images)
#
# # For variational autoencoder
# outputs, mu, logvar = model(images)
# loss = model.loss_criterion(outputs, images, mu, logvar)
# loss.backward()
# optimizer.step()
# train_loss += loss.item() * images.size(0)
# print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
# if epoch == epochs:
# images = torch.randn(512, 200, 100, 3)
# images = images.transpose(3, 2).transpose(2, 1)
# images = images.cuda()
# embedding = model.gen_embedding(images)
# print(embedding.size())
#
#
# if __name__ == '__main__':
# train(10)
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable
import math
import pdb
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Gradient-reversal ramp-up coefficient (DANN/CDAN schedule).

    Rises smoothly from `low` at iter_num=0 toward `high` as
    iter_num -> max_iter, controlled by `alpha`.

    Bug fix: `np.float` was deprecated in NumPy 1.20 and removed in 1.24,
    so the original raised AttributeError; the builtin `float` is the
    documented replacement.
    """
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)
def init_weights(m):
    """Initialise a module in place, dispatching on its class name:
    Kaiming uniform for conv layers, N(1, 0.02) for norm layers, Xavier
    normal for linear layers; the bias is zeroed in every branch."""
    name = m.__class__.__name__
    is_conv = 'Conv2d' in name or 'ConvTranspose2d' in name
    if is_conv:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif 'BatchNorm' in name:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif 'Linear' in name:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
# Map architecture names to their torchvision constructors (see ResNetFc).
resnet_dict = {"ResNet18": models.resnet18, "ResNet34": models.resnet34, "ResNet50": models.resnet50,
               "ResNet101": models.resnet101, "ResNet152": models.resnet152}
def grl_hook(coeff):
    """Return a tensor backward hook that reverses gradients, scaled by coeff."""
    def _reverse(grad):
        return grad.clone() * -coeff
    return _reverse
class ResNetFc(nn.Module):
    """Torchvision ResNet backbone returning (features, logits) pairs.

    With new_cls=True the pretrained 1000-way fc is replaced by a fresh
    classifier, optionally behind a Linear bottleneck of width
    bottleneck_dim; otherwise the pretrained fc is kept unchanged.
    """
    def __init__(self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
        super(ResNetFc, self).__init__()
        # pretrained=True downloads ImageNet weights on first use
        model_resnet = resnet_dict[resnet_name](pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \
                                            self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool)
        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        if new_cls:
            if self.use_bottleneck:
                # shrink backbone features before the new classifier
                self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim)
                self.fc = nn.Linear(bottleneck_dim, class_num)
                self.bottleneck.apply(init_weights)
                self.fc.apply(init_weights)
                self.__in_features = bottleneck_dim
            else:
                self.fc = nn.Linear(model_resnet.fc.in_features, class_num)
                self.fc.apply(init_weights)
                self.__in_features = model_resnet.fc.in_features
        else:
            # keep the pretrained ImageNet classifier
            self.fc = model_resnet.fc
            self.__in_features = model_resnet.fc.in_features
    def forward(self, x):
        """Return (features, logits); features are post-bottleneck when enabled."""
        x = self.feature_layers(x)
        x = x.view(x.size(0), -1)
        if self.use_bottleneck and self.new_cls:
            x = self.bottleneck(x)
        y = self.fc(x)
        return x,y
    def output_num(self):
        """Dimensionality of the feature vector returned by forward()."""
        return self.__in_features
    def get_parameters(self):
        """Optimizer parameter groups: backbone at lr_mult 1, new layers at 10."""
        if self.new_cls:
            if self.use_bottleneck:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
                                  {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
            else:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
        else:
            parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
        return parameter_list
class RandomLayer(nn.Module):
    """Randomized multilinear fusion of several feature vectors (CDAN-style).

    Each input is projected through a fixed random matrix; the projections
    are combined by elementwise product, with the first one pre-scaled by
    output_dim ** (-1/k) for k inputs.
    """
    def __init__(self, input_dim_list=[], output_dim=1024):
        super(RandomLayer, self).__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        # fixed (untrained) random projections, one per input; plain tensors,
        # not registered buffers — see the cuda() override below
        self.random_matrix = [torch.randn(dim, output_dim) for dim in input_dim_list]
    def forward(self, input_list):
        projected = [inp.mm(mat) for inp, mat in zip(input_list, self.random_matrix)]
        fused = projected[0] / math.pow(float(self.output_dim), 1.0 / len(projected))
        for extra in projected[1:]:
            fused = fused * extra
        return fused
    def cuda(self):
        # also move the plain-tensor matrices, which .cuda() would not see
        super(RandomLayer, self).cuda()
        self.random_matrix = [mat.cuda() for mat in self.random_matrix]
# class LRN(nn.Module):
# def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
# super(LRN, self).__init__()
# self.ACROSS_CHANNELS = ACROSS_CHANNELS
# if ACROSS_CHANNELS:
# self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
# stride=1,
# padding=(int((local_size - 1.0) / 2), 0, 0))
# else:
# self.average = nn.AvgPool2d(kernel_size=local_size,
# stride=1,
# padding=int((local_size - 1.0) / 2))
# self.alpha = alpha
# self.beta = beta
#
# def forward(self, x):
# if self.ACROSS_CHANNELS:
# div = x.pow(2).unsqueeze(1)
# div = self.average(div).squeeze(1)
# div = div.mul(self.alpha).add(1.0).pow(self.beta)
# else:
# div = x.pow(2)
# div = self.average(div)
# div = div.mul(self.alpha).add(1.0).pow(self.beta)
# x = x.div(div)
# return x
#
#
# class AlexNet(nn.Module):
#
# def __init__(self, num_classes=1000):
# super(AlexNet, self).__init__()
# self.features = nn.Sequential(
# nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
# nn.ReLU(inplace=True),
# LRN(local_size=5, alpha=0.0001, beta=0.75),
# nn.MaxPool2d(kernel_size=3, stride=2),
# nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
# nn.ReLU(inplace=True),
# LRN(local_size=5, alpha=0.0001, beta=0.75),
# nn.MaxPool2d(kernel_size=3, stride=2),
# nn.Conv2d(256, 384, kernel_size=3, padding=1),
# nn.ReLU(inplace=True),
# nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
# nn.ReLU(inplace=True),
# nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
# nn.ReLU(inplace=True),
# nn.MaxPool2d(kernel_size=3, stride=2),
# )
# self.classifier = nn.Sequential(
# nn.Linear(256 * 6 * 6, 4096),
# nn.ReLU(inplace=True),
# nn.Dropout(),
# nn.Linear(4096, 4096),
# nn.ReLU(inplace=True),
# nn.Dropout(),
# nn.Linear(4096, num_classes),
# )
#
# def forward(self, x):
# x = self.features(x)
# print(x.size())
# x = x.view(x.size(0), 256 * 6 * 6)
# x = self.classifier(x)
# return x
#
#
# def alexnet(pretrained=False, **kwargs):
# r"""AlexNet model architecture from the
# `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
# Args:
# pretrained (bool): If True, returns a model pre-trained on ImageNet
# """
# model = AlexNet(**kwargs)
# if pretrained:
# model_path = './alexnet.pth.tar'
# pretrained_model = torch.load(model_path)
# model.load_state_dict(pretrained_model['state_dict'])
# return model
#
#
# # convnet without the last layer
# class AlexNetFc(nn.Module):
# def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
# super(AlexNetFc, self).__init__()
# model_alexnet = alexnet(pretrained=True)
# self.features = model_alexnet.features
# self.classifier = nn.Sequential()
# for i in range(6):
# self.classifier.add_module("classifier" + str(i), model_alexnet.classifier[i])
# self.feature_layers = nn.Sequential(self.features, self.classifier)
#
# self.use_bottleneck = use_bottleneck
# self.new_cls = new_cls
# if new_cls:
# if self.use_bottleneck:
# self.bottleneck = nn.Linear(4096, bottleneck_dim)
# self.fc = nn.Linear(bottleneck_dim, class_num)
# self.bottleneck.apply(init_weights)
# self.fc.apply(init_weights)
# self.__in_features = bottleneck_dim
# else:
# self.fc = nn.Linear(4096, class_num)
# self.fc.apply(init_weights)
# self.__in_features = 4096
# else:
# self.fc = model_alexnet.classifier[6]
# self.__in_features = 4096
#
# def forward(self, x):
# x = self.features(x)
# x = x.view(x.size(0), -1)
# x = self.classifier(x)
# if self.use_bottleneck and self.new_cls:
# x = self.bottleneck(x)
# y = self.fc(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
# def get_parameters(self):
# if self.new_cls:
# if self.use_bottleneck:
# parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
# return parameter_list
#
#
#
# vgg_dict = {"VGG11": models.vgg11, "VGG13": models.vgg13, "VGG16": models.vgg16, "VGG19": models.vgg19,
# "VGG11BN": models.vgg11_bn, "VGG13BN": models.vgg13_bn, "VGG16BN": models.vgg16_bn,
# "VGG19BN": models.vgg19_bn}
#
#
# class VGGFc(nn.Module):
# def __init__(self, vgg_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
# super(VGGFc, self).__init__()
# model_vgg = vgg_dict[vgg_name](pretrained=True)
# self.features = model_vgg.features
# self.classifier = nn.Sequential()
# for i in range(6):
# self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
# self.feature_layers = nn.Sequential(self.features, self.classifier)
#
# self.use_bottleneck = use_bottleneck
# self.new_cls = new_cls
# if new_cls:
# if self.use_bottleneck:
# self.bottleneck = nn.Linear(4096, bottleneck_dim)
# self.fc = nn.Linear(bottleneck_dim, class_num)
# self.bottleneck.apply(init_weights)
# self.fc.apply(init_weights)
# self.__in_features = bottleneck_dim
# else:
# self.fc = nn.Linear(4096, class_num)
# self.fc.apply(init_weights)
# self.__in_features = 4096
# else:
# self.fc = model_vgg.classifier[6]
# self.__in_features = 4096
#
# def forward(self, x):
# x = self.features(x)
# x = x.view(x.size(0), -1)
# x = self.classifier(x)
# if self.use_bottleneck and self.new_cls:
# x = self.bottleneck(x)
# y = self.fc(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
# def get_parameters(self):
# if self.new_cls:
# if self.use_bottleneck:
# parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
# return parameter_list
#
#
# # For SVHN dataset
# class DTN(nn.Module):
# def __init__(self):
# super(DTN, self).__init__()
# self.conv_params = nn.Sequential(
# nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(64),
# nn.Dropout2d(0.1),
# nn.ReLU(),
# nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(128),
# nn.Dropout2d(0.3),
# nn.ReLU(),
# nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(256),
# nn.Dropout2d(0.5),
# nn.ReLU()
# )
#
# self.fc_params = nn.Sequential(
# nn.Linear(256 * 4 * 4, 512),
# nn.BatchNorm1d(512),
# nn.ReLU(),
# nn.Dropout()
# )
#
# self.classifier = nn.Linear(512, 10)
# self.__in_features = 512
#
# def forward(self, x):
# x = self.conv_params(x)
# x = x.view(x.size(0), -1)
# x = self.fc_params(x)
# y = self.classifier(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
#
# class LeNet(nn.Module):
# def __init__(self):
# super(LeNet, self).__init__()
# self.conv_params = nn.Sequential(
# nn.Conv2d(1, 20, kernel_size=5),
# nn.MaxPool2d(2),
# nn.ReLU(),
# nn.Conv2d(20, 50, kernel_size=5),
# nn.Dropout2d(p=0.5),
# nn.MaxPool2d(2),
# nn.ReLU(),
# )
#
# self.fc_params = nn.Sequential(nn.Linear(50 * 4 * 4, 500), nn.ReLU(), nn.Dropout(p=0.5))
# self.classifier = nn.Linear(500, 10)
# self.__in_features = 500
#
# def forward(self, x):
# x = self.conv_params(x)
# x = x.view(x.size(0), -1)
# x = self.fc_params(x)
# y = self.classifier(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
# class AdversarialNetwork(nn.Module):
# def __init__(self, in_feature, hidden_size):
# super(AdversarialNetwork, self).__init__()
# self.ad_layer1 = nn.Linear(in_feature, hidden_size)
# self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
# self.ad_layer3 = nn.Linear(hidden_size, 1)
# self.relu1 = nn.ReLU()
# self.relu2 = nn.ReLU()
# self.dropout1 = nn.Dropout(0.5)
# self.dropout2 = nn.Dropout(0.5)
# self.sigmoid = nn.Sigmoid()
# self.apply(init_weights)
# self.iter_num = 0
# self.alpha = 10
# self.low = 0.0
# self.high = 1.0
# self.max_iter = 10000.0
#
# def forward(self, x):
# if self.training:
# self.iter_num += 1
# coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
# x = x * 1.0
# x.register_hook(grl_hook(coeff))
# x = self.ad_layer1(x)
# x = self.relu1(x)
# x = self.dropout1(x)
# x = self.ad_layer2(x)
# x = self.relu2(x)
# x = self.dropout2(x)
# y = self.ad_layer3(x)
# y = self.sigmoid(y)
# return y
#
# def output_num(self):
# return 1
#
# def get_parameters(self):
# return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}]
from torchvision import models
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
class GradReverse(Function):
    """Gradient-reversal layer: identity in forward, gradient scaled by
    -lambd in backward.

    Bug fix: the original used the legacy autograd API (instance __init__
    plus non-static forward/backward); calling ``GradReverse(lambd)(x)``
    raises a RuntimeError on PyTorch >= 1.3, where only the static-method
    Function API is supported.
    """
    @staticmethod
    def forward(ctx, x, lambd=1.0):
        ctx.lambd = lambd
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        # reverse and scale the gradient; None for the non-tensor lambd input
        return grad_output * -ctx.lambd, None
def grad_reverse(x, lambd=1.0):
    """Apply gradient reversal to x with coefficient lambd."""
    return GradReverse.apply(x, lambd)
def l2_norm(input):
    """Row-wise L2 normalisation, with a 1e-10 floor inside the sqrt to
    avoid division by zero on all-zero rows."""
    shape = input.size()
    squared = input * input
    norms = (squared.sum(1) + 1e-10).sqrt()
    normalised = input / norms.view(-1, 1).expand_as(input)
    return normalised.view(shape)
class AlexNetBase(nn.Module):
    """AlexNet feature extractor (conv features + classifier[0:6]) followed
    by a feat_bootleneck projection.

    Args:
        pret: load ImageNet-pretrained weights.
        bootleneck_dim: output width of the trailing bottleneck
            (spelling kept for caller compatibility).
    """
    def __init__(self, pret=True,bootleneck_dim=256):
        super(AlexNetBase, self).__init__()
        model_alexnet = models.alexnet(pretrained=pret)
        self.features = nn.Sequential(*list(model_alexnet.
                                            features._modules.values())[:])
        # all fully-connected layers except the final 1000-way classifier
        self.classifier = nn.Sequential()
        for i in range(6):
            self.classifier.add_module("classifier" + str(i),
                                       model_alexnet.classifier[i])
        self.__in_features = model_alexnet.classifier[6].in_features
        self.bottle_neck = feat_bootleneck(feature_dim=4096, bottleneck_dim=bootleneck_dim)
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        x = self.bottle_neck(x)
        return x
    def output_num(self):
        """Feature dimensionality before the bottleneck.

        Bug fix: the original returned `self.__in_feature` (missing 's');
        that mangled attribute is never assigned, so every call raised
        AttributeError.
        """
        return self.__in_features
def init_weights(m):
    """Kaiming init for conv layers, N(1, 0.02) for norm layers, Xavier for
    linear layers. NOTE: unlike the other init_weights in this file, this
    variant zeroes conv/norm biases but leaves linear biases untouched."""
    kind = m.__class__.__name__
    if 'Conv2d' in kind or 'ConvTranspose2d' in kind:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif 'BatchNorm' in kind:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif 'Linear' in kind:
        nn.init.xavier_normal_(m.weight)
class MLP(nn.Module):
    """Single-hidden-layer MLP feature extractor followed by a bottleneck.

    Args:
        pret: accepted for API parity with the image backbones; unused here.
        input_dim: input feature dimensionality.
        bootleneck_dim: output width of the trailing bottleneck.
    """
    def __init__(self, pret=True,input_dim=3000,bootleneck_dim=256):
        super(MLP, self).__init__()
        self.features = nn.Sequential(nn.Linear(input_dim,out_features=500))
        self.features.apply(init_weights)
        self.__in_features = 500
        self.bottle_neck = feat_bootleneck(feature_dim=500, bottleneck_dim=bootleneck_dim)
        self.bottle_neck.apply(init_weights)
    def forward(self, x):
        x = self.features(x)
        x=self.bottle_neck(x)
        return x
    def output_num(self):
        """Feature dimensionality before the bottleneck.

        Bug fix: the original returned `self.__in_feature` (missing 's'),
        a never-assigned mangled attribute, so every call raised
        AttributeError.
        """
        return self.__in_features
class LSTM(nn.Module):
    """Unidirectional LSTM feature extractor over padded sequences.

    Embeds token ids, runs a packed LSTM, projects through a bottleneck and
    returns the output at each sequence's last valid timestep.
    NOTE(review): pack_padded_sequence is called without enforce_sorted=False,
    so batches are assumed sorted by length (descending) — confirm callers.
    """
    def __init__(self, args):
        super(LSTM, self).__init__()
        self.hidden_dim = args.hidden_dim
        self.features = nn.LSTM(args.embedding_dim, self.hidden_dim, num_layers=args.LSTM_layers,
                            batch_first=True, dropout=args.drop_prob, bidirectional=False)
        # self.dropout = nn.Dropout(args.drop_prob)
        # self.fc1 = nn.Linear(self.hidden_dim, 256)
        # self.fc2 = nn.Linear(256, 32)
        # self.fc3 = nn.Linear(32, 2)
        self.bottle_neck=feat_bootleneck(self.hidden_dim, bottleneck_dim=args.bottleneck,type="no")
        self.args=args
        # embeddings are attached later via set_word2vector()
        self.embeddings=None
        # self.linear = nn.Linear(self.hidden_dim, vocab_size)  # output size equals the vocabulary size
    def set_word2vector(self,pre_weight,finetune=True):
        """Load pretrained word vectors (numpy array) into the embedding table."""
        self.embeddings = nn.Embedding.from_pretrained(torch.from_numpy(pre_weight))
        # requires_grad controls whether the embedding weights are fine-tuned during training
        self.embeddings.weight.requires_grad = finetune
    def forward(self, input, batch_seq_len, hidden=None):
        """Return bottlenecked features at each sequence's last valid step.

        Args:
            input: (batch, seq_len) token ids.
            batch_seq_len: true (unpadded) length of each sequence.
            hidden: optional (h_0, c_0); zeros are used when omitted.
        """
        embeds = self.embeddings(input)  # [batch, seq_len] => [batch, seq_len, embed_dim]
        embeds = pack_padded_sequence(embeds, batch_seq_len, batch_first=True)
        batch_size, seq_len = input.size()
        if hidden is None:
            h_0 = input.data.new(self.args.LSTM_layers * 1, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(self.args.LSTM_layers * 1, batch_size, self.hidden_dim).fill_(0).float()
        else:
            h_0, c_0 = hidden
        output, hidden = self.features(embeds, (h_0, c_0))  # hidden is the pair of hidden states (h, c)
        output, _ = pad_packed_sequence(output, batch_first=True)
        # output = self.dropout(torch.tanh(self.fc1(output)))
        # output = torch.tanh(self.fc2(output))
        # output = self.fc3(output)
        output=self.bottle_neck(output)
        last_outputs = self.get_last_output(output, batch_seq_len)
        # output = output.reshape(batch_size * seq_len, -1)
        # return last_outputs, hidden
        return last_outputs
    def get_last_output(self, output, batch_seq_len):
        """Gather, per sequence, the output at its last valid timestep."""
        last_outputs = torch.zeros((output.shape[0], output.shape[2]))
        for i in range(len(batch_seq_len)):
            last_outputs[i] = output[i][batch_seq_len[i] - 1]  # index is length - 1 (last valid timestep)
        last_outputs = last_outputs.to(output.device)
        return last_outputs
class feat_bootleneck(nn.Module):
    """Linear bottleneck projection, optionally followed by BatchNorm + Dropout.

    type == "bn": linear -> batchnorm -> dropout(0.3);
    any other type: linear only.
    """
    def __init__(self, feature_dim, bottleneck_dim=256, type="bn"):
        super(feat_bootleneck, self).__init__()
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.3)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type
    def forward(self, x):
        projected = self.bottleneck(x)
        if self.type != "bn":
            return projected
        return self.dropout(self.bn(projected))
class VGGBase(nn.Module):
    """VGG16 backbone: conv features + fc layers (minus the final 1000-way
    classifier), followed by a feat_bootleneck projection."""
    def __init__(self, pret=True, no_pool=False,bootleneck_dim=256):
        super(VGGBase, self).__init__()
        vgg16 = models.vgg16(pretrained=pret)
        # all fully-connected layers except the last (ImageNet classifier)
        self.classifier = nn.Sequential(*list(vgg16.classifier._modules.values())[:-1])
        self.features = nn.Sequential(*list(vgg16.features._modules.values()))
        # learnable scale (temperature) parameter, initialised to 10
        self.s = nn.Parameter(torch.FloatTensor([10]))
        self.bottle_neck = feat_bootleneck(feature_dim=4096, bottleneck_dim=bootleneck_dim)
    def forward(self, x):
        feats = self.features(x)
        feats = feats.view(feats.size(0), 7 * 7 * 512)
        feats = self.classifier(feats)
        return self.bottle_neck(feats)
class VGGBase_no_neck(nn.Module):
    """VGG16 backbone (conv + fc minus the final classifier) with no
    bottleneck head; returns raw 4096-dim fc features."""
    def __init__(self, pret=True, no_pool=False,bootleneck_dim=256):
        super(VGGBase_no_neck, self).__init__()
        vgg16 = models.vgg16(pretrained=pret)
        # all fully-connected layers except the last (ImageNet classifier)
        self.classifier = nn.Sequential(*list(vgg16.classifier._modules.values())[:-1])
        self.features = nn.Sequential(*list(vgg16.features._modules.values()))
        # learnable scale parameter, kept for API parity with VGGBase
        self.s = nn.Parameter(torch.FloatTensor([10]))
        # NOTE(review): bootleneck_dim is accepted but unused in this variant
    def forward(self, x):
        feats = self.features(x)
        feats = feats.view(feats.size(0), 7 * 7 * 512)
        return self.classifier(feats)
momentum = 0.001
def mish(x):
    """Mish activation: x * tanh(softplus(x)) (https://arxiv.org/abs/1908.08681)."""
    return torch.tanh(F.softplus(x)) * x
class PSBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that adds a constant `alpha` to its output.

    Counters collapsed filters (https://arxiv.org/abs/2001.11216).
    """
    def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha
    def forward(self, x):
        normalised = super().forward(x)
        return normalised + self.alpha
class BasicBlock(nn.Module):
    """Pre-activation WideResNet basic block: BN -> LeakyReLU -> 3x3 conv,
    twice, plus a 1x1 conv shortcut when channel counts differ."""
    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=True)
        self.drop_rate = drop_rate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 conv shortcut only when channel counts differ
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                padding=0, bias=True) or None
        self.activate_before_residual = activate_before_residual
    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual == True:
            # pre-activate the input itself so the shortcut also sees BN+ReLU
            x = self.relu1(self.bn1(x))
        else:
            # Bug fix: removed leftover debug `print(x.shape)` that spammed
            # stdout on every forward pass (the duplicate BasicBlock later in
            # this file already has it commented out).
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
    """A stack of `nb_layers` residual blocks; only the first block may
    change channel count and stride."""
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
        blocks = [
            block(in_planes if i == 0 else out_planes, out_planes,
                  stride if i == 0 else 1, drop_rate, activate_before_residual)
            for i in range(int(nb_layers))
        ]
        return nn.Sequential(*blocks)
    def forward(self, x):
        return self.layer(x)
class WideResNetVar(nn.Module):
    """WideResNet variant with four network blocks and a bottleneck head.

    forward(x) returns the bottleneck features; with ood_test=True it also
    returns the pooled pre-bottleneck features, and with is_remix=True it
    additionally returns rotation-classifier logits.
    """
    def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False):
        super(WideResNetVar, self).__init__()
        channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor, 128 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
                               padding=1, bias=True)
        # 1st block
        self.block1 = NetworkBlock(
            n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(
            n, channels[1], channels[2], block, 2, drop_rate)
        # 3rd block
        self.block3 = NetworkBlock(
            n, channels[2], channels[3], block, 2, drop_rate)
        # 4th block
        self.block4 = NetworkBlock(
            n, channels[3], channels[4], block, 2, drop_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(channels[4], momentum=0.001, eps=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        # Bug fix: self.block1 was missing from this Sequential, so forward()
        # fed conv1's channels[0]-wide output straight into block2 (which
        # expects channels[1] inputs) and crashed for widen_factor > 1;
        # compare WideResNet later in this file, which includes block1.
        self.features = nn.Sequential(self.conv1, self.block1, self.block2, self.block3, self.block4, self.bn1, self.relu)
        self.bottle_neck = feat_bootleneck(feature_dim=channels[4], bottleneck_dim=256)
        # self.fc = nn.Linear(, num_classes)
        self.channels = channels[4]
        # rot_classifier for Remix Match
        self.is_remix = is_remix
        if is_remix:
            self.rot_classifier = nn.Linear(self.channels, 4)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()
    def forward(self, x, ood_test=False):
        """Return bottleneck features (plus extras per ood_test/is_remix)."""
        out = self.features(x)
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(-1, self.channels)
        output = self.bottle_neck(out)
        if ood_test:
            return output, out
        else:
            if self.is_remix:
                rot_output = self.rot_classifier(out)
                return output, rot_output
            else:
                return output
class build_WideResNetVar:
    """Factory storing WideResNetVar hyper-parameters; build() instantiates."""
    def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.1, dropRate=0.0,
                 use_embed=False, is_remix=False):
        self.first_stride = first_stride
        self.depth = depth
        self.widen_factor = widen_factor
        self.bn_momentum = bn_momentum
        self.dropRate = dropRate
        self.leaky_slope = leaky_slope
        self.use_embed = use_embed
        self.is_remix = is_remix
    def build(self, num_classes):
        """Instantiate a WideResNetVar for `num_classes` with stored settings."""
        cfg = dict(
            first_stride=self.first_stride,
            depth=self.depth,
            num_classes=num_classes,
            widen_factor=self.widen_factor,
            drop_rate=self.dropRate,
            is_remix=self.is_remix,
        )
        return WideResNetVar(**cfg)
import torch.nn.utils.weight_norm as weightNorm
class Predictor(nn.Module):
    """Linear classifier head.

    With norm_feature truthy, the input is L2-normalised and the logits are
    divided by a temperature; with reverse=True the input first passes
    through a gradient-reversal layer scaled by eta.
    """
    def __init__(self, num_class=64, inc=4096, temp=0.05,norm_feature=1):
        super(Predictor, self).__init__()
        self.fc = nn.Linear(inc, num_class, bias=True)
        self.fc.apply(init_weights)
        self.num_class = num_class
        self.temp = temp
        self.norm_feature = norm_feature
    def forward(self, x, reverse=False, eta=0.1):
        if reverse:
            x = grad_reverse(x, eta)
        if not self.norm_feature:
            return self.fc(x)
        return self.fc(F.normalize(x)) / self.temp
def init_weights(m):
    """In-place init dispatcher keyed on the module's class name:
    conv -> Kaiming uniform, norm -> N(1, 0.02), linear -> Xavier normal;
    every branch also zeroes the bias."""
    cls = m.__class__.__name__
    if cls.find('Conv2d') >= 0 or cls.find('ConvTranspose2d') >= 0:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif cls.find('BatchNorm') >= 0:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif cls.find('Linear') >= 0:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
class Predictor_deep(nn.Module):
    """Two-layer classifier head: fc1 -> BN -> ReLU -> Dropout -> fc2.

    With norm_feature truthy, the hidden features are L2-normalised and the
    logits divided by a temperature. The `reverse`/`eta` parameters are
    accepted for API parity with Predictor but unused here.
    """
    def __init__(self, num_class=64, inc=4096, norm_feature=1,temp=0.05):
        super(Predictor_deep, self).__init__()
        self.fc1 = nn.Linear(inc, inc//2)
        self.fc1.apply(init_weights)
        self.fc2 = nn.Linear(inc//2, num_class,bias=False)
        nn.init.xavier_normal_(self.fc2.weight)
        self.bn = nn.BatchNorm1d(inc//2, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.num_class = num_class
        self.temp = temp
        self.norm_feature = norm_feature
    def forward(self, x, reverse=False, eta=0.1):
        hidden = self.dropout(self.relu(self.bn(self.fc1(x))))
        if not self.norm_feature:
            return self.fc2(hidden)
        return self.fc2(F.normalize(hidden)) / self.temp
momentum = 0.001
def mish(x):
    """Mish activation (https://arxiv.org/abs/1908.08681): x * tanh(softplus(x))."""
    gate = torch.tanh(F.softplus(x))
    return x * gate
class PSBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d whose output is shifted by a constant alpha
    (https://arxiv.org/abs/2001.11216, counters collapsed filters)."""
    def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.alpha = alpha
    def forward(self, x):
        return self.alpha + super().forward(x)
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN -> LeakyReLU -> 3x3 conv, twice,
    with an optional 1x1 conv shortcut when channel counts differ."""
    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=True)
        self.drop_rate = drop_rate
        self.equalInOut = in_planes == out_planes
        if self.equalInOut:
            self.convShortcut = None
        else:
            # 1x1 conv to match the residual's channel count / stride
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=True)
        self.activate_before_residual = activate_before_residual
    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual == True:
            # pre-activate the input itself so the shortcut also sees BN+ReLU
            x = self.relu1(self.bn1(x))
            conv_in = x
        else:
            # bn1 is evaluated unconditionally here (as in the original) so
            # its running statistics update the same way in training mode
            pre = self.relu1(self.bn1(x))
            conv_in = pre if self.equalInOut else x
        out = self.relu2(self.bn2(self.conv1(conv_in)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)
class NetworkBlock(nn.Module):
    """Sequential stack of `nb_layers` residual blocks; the first block
    adapts channels/stride, the remainder keep them fixed."""
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(
            block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
        stack = nn.Sequential()
        for idx in range(int(nb_layers)):
            first = idx == 0
            stack.add_module(str(idx), block(in_planes if first else out_planes, out_planes,
                                             stride if first else 1, drop_rate, activate_before_residual))
        return stack
    def forward(self, x):
        return self.layer(x)
class WideResNet(nn.Module):
    """Standard WideResNet feature extractor (three network blocks).

    forward(x) returns the globally pooled, flattened features (no
    classification head is applied here).
    """
    def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False):
        super(WideResNet, self).__init__()
        channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        # depth must satisfy the WRN layer equation depth = 6n + 4
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        # print("channels",channels)
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
                               padding=1, bias=True)
        # 1st block
        self.block1 = NetworkBlock(
            n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(
            n, channels[1], channels[2], block, 2, drop_rate)
        # 3rd block
        self.block3 = NetworkBlock(
            n, channels[2], channels[3], block, 2, drop_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.features = nn.Sequential(self.conv1,self.block1, self.block2,self.block3,self.bn1,self.relu)
        # self.bottle_neck=feat_bootleneck(feature_dim=channels[3], bottleneck_dim=256)
        # self.fc = nn.Linear(channels[3], num_classes)
        self.channels = channels[3]
        # rot_classifier for Remix Match
        self.is_remix = is_remix
        if is_remix:
            self.rot_classifier = nn.Linear(self.channels, 4)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()
    def forward(self, x, ood_test=False):
        """Return pooled features.

        With ood_test=True also returns the un-flattened pooled map; with
        is_remix=True also returns rotation-classifier logits.
        """
        # out = self.conv1(x)
        # out = self.block1(out)
        # out = self.block2(out)
        # out = self.block3(out)
        # out = self.relu(self.bn1(out))
        # print("kllllll")
        out=self.features(x)
        out = F.adaptive_avg_pool2d(out, 1)
        output = out.view(-1, self.channels)
        # print(self.channels)
        # output = self.fc(out)
        if ood_test:
            return output, out
        else:
            if self.is_remix:
                rot_output = self.rot_classifier(out)
                return output, rot_output
            else:
                return output
class build_WideResNet:
    """Configuration holder that builds WideResNet instances on demand."""

    def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.1, dropRate=0.0,
                 use_embed=False, is_remix=False):
        # Stash every knob; build() consumes a subset of them.
        self.first_stride = first_stride
        self.depth = depth
        self.widen_factor = widen_factor
        self.bn_momentum = bn_momentum
        self.dropRate = dropRate
        self.leaky_slope = leaky_slope
        self.use_embed = use_embed
        self.is_remix = is_remix

    def build(self, num_classes):
        """Instantiate a WideResNet from the stored configuration."""
        cfg = dict(
            first_stride=self.first_stride,
            depth=self.depth,
            num_classes=num_classes,
            widen_factor=self.widen_factor,
            drop_rate=self.dropRate,
            is_remix=self.is_remix,
        )
        return WideResNet(**cfg)
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torch.autograd import Function
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
# Public API of this (torchvision-style) ResNet module.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']
# Download locations for the ImageNet-pretrained checkpoints.
model_urls = {
    'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
    'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
    'resnet101':
    'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
    'resnet152':
    'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
}
def init_weights(m):
    """Module initializer (use with ``module.apply``).

    Conv / ConvTranspose: Kaiming-uniform weight, zero bias (if present).
    BatchNorm: N(1.0, 0.02) weight, zero bias (if present).
    Linear: Xavier-normal weight; bias intentionally left at its default
    initialization in this variant.
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or \
            classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        # BUGFIX: guard against bias-free layers — conv3x3 in this file
        # builds convs with bias=False, on which zeros_(None) crashed.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class GradReverse(Function):
    """Gradient-reversal op: identity forward, gradients scaled by -lambd.

    Ported to the static-method ``torch.autograd.Function`` API. The legacy
    instantiate-then-call style (``GradReverse(lambd)(x)``) was removed from
    PyTorch and raises at runtime on any modern version.
    """

    @staticmethod
    def forward(ctx, x, lambd):
        # Remember the scale for the backward pass.
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient w.r.t. lambd (it is a plain float).
        return grad_output * -ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    """Gradient-reversal layer used for adversarial domain adaptation."""
    return GradReverse.apply(x, lambd)
class BasicBlock(nn.Module):
    """Standard two-conv residual block (torchvision-style)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, nobn=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.nobn = nobn

    def forward(self, x, source=True):
        # Main path conv-bn-relu-conv-bn, plus the (possibly projected)
        # shortcut, then the final ReLU.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class ScaleLayer(nn.Module):
    """Multiplies its input by a single learnable scalar."""

    def __init__(self, init_value=1e-3):
        super(ScaleLayer, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, input):
        # BUGFIX: removed the stray debug `print(self.scale)` that spammed
        # stdout on every forward pass.
        return input * self.scale
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (stride on the first 1x1 conv)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, nobn=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1,
                               stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.nobn = nobn

    def forward(self, x):
        # Bottleneck path with ReLU between stages, no ReLU before the add.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
class feat_bootleneck(nn.Module):
    """Linear bottleneck projection followed by optional BN, then dropout."""

    def __init__(self, feature_dim, bottleneck_dim=256, type="bn"):
        super(feat_bootleneck, self).__init__()
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.1)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type

    def forward(self, x):
        projected = self.bottleneck(x)
        if self.type != "bn":
            return self.dropout(projected)
        return self.dropout(self.bn(projected))
class ResNet(nn.Module):
    """ResNet backbone returning bottleneck features.

    forward() runs the convolutional pipeline and the feat_bootleneck
    projection only; ``self.fc`` is constructed but never applied there.
    """

    def __init__(self, block, layers, num_classes=1000, bottleneck=256):
        self.inplanes = 64
        super(ResNet, self).__init__()
        # Stem: 7x7 stride-2 conv + BN, then two InstanceNorms (see note).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE(review): in1 and in2 are applied back-to-back inside
        # self.features; in2 is declared for 128 channels but receives
        # 64-channel input (InstanceNorm2d without affine params does not
        # check this) — confirm this is intentional.
        self.in1 = nn.InstanceNorm2d(64)
        self.in2 = nn.InstanceNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,
                                    padding=0, ceil_mode=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # End-to-end convolutional pipeline used by forward().
        self.features = nn.Sequential(
            self.conv1,
            self.bn1,
            self.in1,
            self.in2,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        self.bottle_neck = feat_bootleneck(512 * block.expansion, bottleneck)
        # He-style init for convs; unit-gamma / zero-beta for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, nobn=False):
        """Build one ResNet stage of ``blocks`` residual blocks."""
        downsample = None
        # Project the shortcut when the stage changes resolution or width.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, nobn=nobn))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return bottleneck features for ``x`` (classifier fc not applied)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.bottle_neck(x)
        return x
def resnet18(pretrained=True):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet18'])
        state = model.state_dict()
        # Keep only checkpoint entries present in this model, merge them
        # over the fresh weights, and load the result.
        state.update({k: v for k, v in checkpoint.items() if k in state})
        model.load_state_dict(state)
    return model
def resnet34(pretrained=True, args=None):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        args: optional namespace providing ``bottleneck`` (bottleneck width);
            the ResNet default is used when omitted.
    """
    # BUGFIX: the original dereferenced ``args.bottleneck`` unconditionally,
    # crashing with AttributeError for the documented default ``args=None``.
    if args is not None:
        model = ResNet(BasicBlock, [3, 4, 6, 3], bottleneck=args.bottleneck)
    else:
        model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        pretrained_dict = model_zoo.load_url(model_urls['resnet34'])
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
    return model
def resnet50(pretrained=True, args=None):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        args: optional namespace providing ``bottleneck`` (bottleneck width);
            the ResNet default is used when omitted.
    """
    # BUGFIX: the original dereferenced ``args.bottleneck`` unconditionally,
    # crashing with AttributeError for the documented default ``args=None``.
    if args is not None:
        model = ResNet(Bottleneck, [3, 4, 6, 3], bottleneck=args.bottleneck)
    else:
        model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
        model_dict = model.state_dict()
        # Keep only checkpoint entries present in this model, then load.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
def resnet101(pretrained=False):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
def resnet152(pretrained=False):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
"""Meta-gradient compatible load_state_dict method"""
# pylint: disable=protected-access
# pylint: disable=redefined-builtin
from collections import OrderedDict
def _load_from_par_dict(module, par_dict, prefix):
    """Replace the module's _parameter dict with par_dict."""
    replaced = OrderedDict()
    for name, param in module._parameters.items():
        full_name = prefix + name
        # Take the tensor from par_dict when present, otherwise keep the
        # module's current parameter. Shape checking was intentionally
        # disabled upstream (Ning Ma, 2020.05.24) to allow mismatched loads.
        replaced[name] = par_dict.get(full_name, param)
    module._parameters = replaced
def load_state_dict(module, state_dict):
    r"""Replaces parameters and buffers (meta-gradient compatible).

    Replaces parameters and buffers from :attr:`state_dict` into
    the given module and its descendants. In contrast to the module's
    method, this function will *not* do an in-place copy of underlying data
    on *parameters*, but instead replaces the ``_parameters`` dict in each
    module and its descendants. This allows us to backpropagate through
    previous gradient steps using the standard top-level API.

    .. note::
        You must store the original state dict (with keep_vars=True)
        separately and, when ready to update them, use
        :func:`load_state_dict` to return it as the module's parameters.

    Arguments:
        module (torch.nn.Module): a module instance whose state to update.
        state_dict (dict): a dict containing parameters and
            persistent buffers.
    """
    # Split the incoming state into parameter entries and buffer entries.
    par_names = [n for n, _ in module.named_parameters()]
    par_dict = OrderedDict({k: v for k, v in state_dict.items()
                            if k in par_names})
    no_par_dict = OrderedDict({k: v for k, v in state_dict.items()
                               if k not in par_names})
    excess = [k for k in state_dict.keys()
              if k not in list(no_par_dict.keys()) + list(par_dict.keys())]
    if excess:
        raise ValueError(
            "State variables %r not in the module's state dict %r" % (
                excess, par_names))
    # Preserve serialization metadata on both partitions, if present.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        par_dict._metadata = metadata
        no_par_dict._metadata = metadata
    # Buffers go through the regular (copying) loader; strict=False because
    # the parameter entries were removed above.
    module.load_state_dict(no_par_dict, strict=False)

    def load(module, prefix=''):  # pylint: disable=missing-docstring
        # Recursively swap each submodule's _parameters for par_dict entries.
        _load_from_par_dict(module, par_dict, prefix)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class Linear_fw(nn.Linear):
    """nn.Linear with an optional MAML 'fast weight' override.

    When ``weight.fast`` / ``bias.fast`` are set, forward() uses them
    instead of the registered parameters, letting inner-loop updates flow
    gradients back to the meta-parameters.
    """

    def __init__(self, in_features, out_features):
        super(Linear_fw, self).__init__(in_features, out_features)
        self.in_features = in_features
        self.out_features = out_features
        # Lazy hack: hang the fast-weight slots off the parameters themselves.
        self.weight.fast = None
        self.bias.fast = None

    def __deepcopy__(self, memo):
        # Custom deepcopy: clone data into fresh Parameters and reset the
        # fast-weight links so copies never alias the originals.
        if id(self) in memo:
            return memo[id(self)]
        clone = type(self)(self.in_features, self.out_features)
        w = Parameter(self.weight.data.clone(), self.weight.requires_grad)
        b = Parameter(self.bias.data.clone(), self.bias.requires_grad)
        w.fast = None
        b.fast = None
        clone.weight = w
        clone.bias = b
        memo[id(self)] = clone
        return clone

    def forward(self, x):
        if self.weight.fast is not None and self.bias.fast is not None:
            # fast weights are the temporarily adapted (inner-loop) weights
            return F.linear(x, self.weight.fast, self.bias.fast)
        return super(Linear_fw, self).forward(x)
class Conv2d_fw(nn.Conv2d):
    """nn.Conv2d with MAML fast-weight support (see Linear_fw)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                        bias=bias)
        self.weight.fast = None
        if self.bias is not None:
            self.bias.fast = None
        # Remember the raw constructor arguments so __deepcopy__ can rebuild
        # an identical layer.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.curent_bias = bias  # (sic) kept for backward compatibility

    def __deepcopy__(self, memo):
        # Clone data into fresh Parameters and reset the fast-weight links.
        if id(self) in memo:
            return memo[id(self)]
        result = type(self)(self.in_channels, self.out_channels, self.kernel_size,
                            self.stride, self.padding, self.curent_bias)
        weights = Parameter(self.weight.data.clone(), self.weight.requires_grad)
        weights.fast = None
        result.weight = weights
        # BUGFIX: the original unconditionally cloned self.bias, crashing
        # with AttributeError for layers constructed with bias=False.
        if self.bias is not None:
            bias = Parameter(self.bias.data.clone(), self.bias.requires_grad)
            bias.fast = None
            result.bias = bias
        memo[id(self)] = result
        return result

    def forward(self, x):
        if self.bias is None:
            if self.weight.fast is not None:
                out = F.conv2d(x, self.weight.fast, None, stride=self.stride, padding=self.padding)
            else:
                out = super(Conv2d_fw, self).forward(x)
        else:
            if self.weight.fast is not None and self.bias.fast is not None:
                out = F.conv2d(x, self.weight.fast, self.bias.fast, stride=self.stride, padding=self.padding)
            else:
                out = super(Conv2d_fw, self).forward(x)
        return out
class BatchNorm2d_fw(nn.BatchNorm2d):
    """BatchNorm2d with MAML fast weights; always normalises with batch stats."""

    def __init__(self, num_features):
        super(BatchNorm2d_fw, self).__init__(num_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        # Fresh zero-mean / unit-var buffers every call; with momentum=1 the
        # normalisation therefore always uses the current batch statistics
        # (hack following Kate Rakelly's pytorch-maml layers.py).
        # BUGFIX: allocate on x's device instead of hard-coding .cuda(),
        # so the layer also runs on CPU (identical behavior on GPU).
        running_mean = torch.zeros(x.data.size()[1], device=x.device)
        running_var = torch.ones(x.data.size()[1], device=x.device)
        if self.weight.fast is not None and self.bias.fast is not None:
            out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True,
                               momentum=1)
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1)
        return out
class BatchNorm1d_fw(nn.BatchNorm1d):
    """BatchNorm1d with MAML fast weights; always normalises with batch stats."""

    def __init__(self, num_features, affin=True):
        super(BatchNorm1d_fw, self).__init__(num_features, affine=affin)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        # Fresh zero-mean / unit-var buffers every call; with momentum=1 the
        # normalisation always uses the current batch statistics
        # (hack following Kate Rakelly's pytorch-maml layers.py).
        # BUGFIX: allocate on x's device instead of hard-coding .cuda(),
        # so the layer also runs on CPU (identical behavior on GPU).
        running_mean = torch.zeros(x.data.size()[1], device=x.device)
        running_var = torch.ones(x.data.size()[1], device=x.device)
        if self.weight.fast is not None and self.bias.fast is not None:
            out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training=True,
                               momentum=1)
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training=True, momentum=1)
        return out
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch import optim
import numpy as np
from scipy.spatial.distance import cdist
from loss import CrossEntropyLabelSmooth
from tqdm import tqdm
from loss import Entropy
from copy import deepcopy
class Meta(nn.Module):
    """
    Meta Learner: MAML-style inner/outer loop over (encoder, bottleneck,
    classifier), with optional pseudo-labelling by nearest-class-center
    clustering (SHOT-style) and target-domain fine-tuning.
    """
    def __init__(self, encoder, classifier, feat_bootleneck, config):
        """
        :param encoder: backbone feature extractor
        :param classifier: classification head
        :param feat_bootleneck: bottleneck module between encoder and head
        :param config: dict of hyper-parameters (learning rates, step
            counts, clipping, scheduler patience, ...)
        """
        super(Meta, self).__init__()
        self.update_lr = config["inner_lr"]      # inner-loop (fast) lr
        self.meta_lr = config["outer_lr"]        # outer-loop (meta) lr
        self.n_way = config["class_num"]
        self.update_step = config["update_step"]
        self.clip = config["grad_clip"]
        self.update_step_test = config["step_test"]
        self.encoder = encoder
        self.classifier = classifier
        self.global_center = None                # cached target class centers
        self.feat_bootleneck = feat_bootleneck
        self.config = config
        self.ft_lr = config["ft_lr"]
        # One param group per parameter, all at the outer-loop learning rate.
        param_group = []
        learning_rate = self.meta_lr
        for k, v in self.encoder.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        for k, v in self.feat_bootleneck.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        for k, v in self.classifier.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        self.meta_optim = optim.SGD(param_group, momentum=0.9, nesterov=True, weight_decay=config["weight_decay"])
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.meta_optim, mode='min',
                                                                    factor=0.2, patience=config["patience"],
                                                                    verbose=True, min_lr=1e-6)
        # Gradient-accumulation bookkeeping: the optimizer is stepped only
        # every `batch_per_episodes` tasks (see forward()).
        self.task_index = 1
        self.update_flag = config["batch_per_episodes"]
        self.loss = CrossEntropyLabelSmooth(self.n_way, device=None)
        self.cluster = False                     # enable pseudo-label inner loop

    def obtain_center(self, data, netF, netC, netB=None):
        """Estimate per-class feature centers for one batch.

        Two rounds of weighted clustering in bias-augmented, L2-normalised
        feature space (softmax-weighted means, then hard-assignment means),
        following SHOT-style pseudo-label refinement.
        """
        start_test = True
        with torch.no_grad():
            inputs = data[0]
            labels = data[1]
            feas = netB(netF(inputs))
            outputs = netC(feas)
            all_fea = feas.float().cpu()
            all_output = outputs.float().cpu()
            all_label = labels.float().cpu()
            all_output = nn.Softmax(dim=1)(all_output)
            _, predict = torch.max(all_output, 1)
            accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
            # Append a constant 1 so centers carry a bias term, then L2-normalise.
            all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
            all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
            all_fea = all_fea.float().cpu().numpy()
            K = all_output.size(1)
            aff = all_output.float().cpu().numpy()
            # Round 1: softmax-weighted class means.
            initc = aff.transpose().dot(all_fea)
            initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
            dd = cdist(all_fea, initc, 'cosine')
            pred_label = dd.argmin(axis=1)
            acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
            # Round 2: hard-assignment means from the round-1 labels.
            aff = np.eye(K)[pred_label]
            initc = aff.transpose().dot(all_fea)
            initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
            center = torch.from_numpy(initc).cuda()
            log_str = 'Accuracy = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
            return center

    def obtain_label(self, features_target, center, device=None):
        """Pseudo-label features by cosine distance to the given centers."""
        features_target = torch.cat((features_target, torch.ones(features_target.size(0), 1).cuda()), 1)
        fea = features_target.float().detach().cpu().numpy()
        center = center.float().detach().cpu().numpy()
        dis = cdist(fea, center, 'cosine') + 1
        pred = np.argmin(dis, axis=1)
        pred = torch.from_numpy(pred).cuda()
        return pred

    def forward(self, support_data, query_data):
        """
        One meta-training episode (single task).

        :param support_data: [support_images, support_labels]
        :param query_data: [query_images, query_labels]
        :return: (per-inner-step query accuracies, final query loss value)
        """
        [support_image, support_label] = support_data
        [query_image, query_label] = query_data
        task_num = 1
        querysz = query_label.size()[0]
        losses_q = [0 for _ in range(self.update_step)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step)]
        for i in range(task_num):
            # The first gradient is computed w.r.t. the original (slow) weights.
            fast_parameters = list(self.parameters())
            for weight in self.parameters():
                weight.fast = None
            self.zero_grad()
            for k in range(0, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                if self.cluster is True:
                    # Pseudo-label the support set by clustering, then
                    # compute the inner loss on those labels.
                    center = self.obtain_center(support_data, self.encoder, netC=self.classifier, netB=self.feat_bootleneck)
                    with torch.no_grad():
                        features_support = self.feat_bootleneck(self.encoder(support_image))
                        pred = self.obtain_label(features_support, center, None)
                    logits = self.classifier(self.feat_bootleneck(self.encoder(support_image)))
                    loss = self.loss(logits, pred)
                else:
                    logits = self.classifier(self.feat_bootleneck(self.encoder(support_image)))
                    loss = self.loss(logits, support_label)
                # Build a graph that supports gradient-of-gradient
                # (second-order MAML).
                grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
                fast_parameters = []
                for index, weight in enumerate(self.parameters()):
                    # for usage of weight.fast, see Linear_fw / Conv2d_fw
                    if weight.fast is None:
                        weight.fast = weight - self.update_lr * grad[index]  # create weight.fast
                    else:
                        # create an updated weight.fast; note the '-' is not
                        # merely a minus, it creates a new weight.fast tensor
                        weight.fast = weight.fast - self.update_lr * grad[index]
                    # gradients are based on the newest fast weights, but the
                    # graph retains links to the older weight.fast tensors
                    fast_parameters.append(weight.fast)
                logits_q = self.classifier(self.feat_bootleneck(self.encoder(query_image)))
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = self.loss(logits_q, query_label)
                losses_q[k] += loss_q
                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, query_label).sum().item()  # convert to numpy
                    corrects[k] = corrects[k] + correct
        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = losses_q[-1] / task_num
        # optimize theta parameters (outer loop) with gradient accumulation:
        # backward now, step only every `update_flag` tasks.
        loss_q.backward()
        if self.task_index == self.update_flag:
            if self.clip > 0.1:  # 0.1 threshold whether to do clip
                torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
            self.meta_optim.step()
            self.meta_optim.zero_grad()
            self.task_index = 1
        else:
            self.task_index = self.task_index + 1
        accs = 100 * np.array(corrects) / (querysz * task_num)
        return accs, loss_q.item()

    def finetunning(self, support_data, query_data):
        """
        Test-time adaptation on one task: inner-loop updates on
        pseudo-labelled support data, evaluating on the query set after
        each step. The meta (slow) weights are not updated here.

        :return: (per-step query accuracies, list of query losses)
        """
        [support_image, support_label] = support_data
        [query_image, query_label] = query_data
        task_num = 1
        querysz = query_label.size()[0]
        losses_q = [0 for _ in range(self.update_step_test)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step_test)]
        for i in range(task_num):
            # The first gradient is computed w.r.t. the original weights.
            fast_parameters = list(
                self.parameters())
            for weight in self.parameters():
                weight.fast = None
            self.zero_grad()
            for k in range(0, self.update_step_test):
                # NOTE(review): these deepcopies are never used below —
                # obtain_center is called with self.encoder/... directly;
                # presumably leftover from an earlier revision.
                netF = deepcopy(self.encoder)
                netB = deepcopy(self.feat_bootleneck)
                if self.global_center is not None:
                    center = self.global_center
                else:
                    data = [support_image, support_label]
                    center = self.obtain_center(data, netF=self.encoder, netC=self.classifier, netB=self.feat_bootleneck)
                with torch.no_grad():
                    # NOTE(review): unlike forward(), this path feeds raw
                    # encoder features (no feat_bootleneck) to the
                    # classifier — confirm intended.
                    features_support = self.encoder(support_image)
                    pred = self.obtain_label(features_support, center, device=None)
                logits = self.classifier(self.encoder(support_image))
                loss = self.loss(logits, pred)
                # Build graph supporting gradient of gradient.
                grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
                fast_parameters = []
                for index, weight in enumerate(self.parameters()):
                    # for usage of weight.fast, see Linear_fw / Conv2d_fw
                    if weight.fast is None:
                        weight.fast = weight - self.update_lr * grad[index]  # create weight.fast
                    else:
                        # create an updated weight.fast (new tensor each step)
                        weight.fast = weight.fast - self.update_lr * grad[index]
                    fast_parameters.append(weight.fast)
                logits_q = self.classifier(self.encoder(query_image))
                # loss_q keeps only the loss at the current step index.
                loss_q = self.loss(logits_q, query_label)
                losses_q[k] += loss_q
                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, query_label).sum().item()  # convert to numpy
                    corrects[k] = corrects[k] + correct
        # NOTE(review): operator precedence makes this
        # (100 * corrects / querysz) * task_num, unlike forward() which
        # divides by (querysz * task_num) — confirm which is intended
        # (equivalent only while task_num == 1).
        accs = 100 * np.array(corrects) / querysz * task_num
        return accs, losses_q

    def finetunning_shot(self, target_dataset):
        """SHOT-style target adaptation: freeze the classifier and fine-tune
        deep copies of encoder + bottleneck with an information-maximisation
        loss plus cluster-pseudo-label cross-entropy."""
        for weight in self.parameters():
            weight.fast = None
        self.zero_grad()
        ## set base network
        netF = deepcopy(self.encoder)
        netC = deepcopy(self.classifier)
        netB = deepcopy(self.feat_bootleneck)
        # The classifier (hypothesis) is kept frozen during adaptation.
        for k, v in netC.named_parameters():
            v.requires_grad = False
        param_group = []
        for k, v in netF.named_parameters():
            param_group += [{'params': v, 'lr': self.ft_lr}]
        for k, v in netB.named_parameters():
            param_group += [{'params': v, 'lr': self.ft_lr}]
        optimizer = optim.SGD(param_group, momentum=0.9, weight_decay=5e-4, nesterov=True)
        for epoch in tqdm(range(self.config["ft_epoch"]), leave=False):
            netF.train()
            netB.train()
            im_loss, classifier_loss = None, None
            # Freeze a snapshot of the current nets so pseudo-labels stay
            # stable within the epoch.
            prev_F = deepcopy(netF)
            prev_B = deepcopy(netB)
            prev_F.eval()
            prev_B.eval()
            center = self.build_global_center(target_dataset, prev_F, prev_B, netC)
            for _, data in tqdm(enumerate(target_dataset), leave=False):
                inputs_test = data['T'].cuda()
                with torch.no_grad():
                    features_test = prev_B(prev_F(inputs_test))
                    pred = self.obtain_label(features_test, center)
                features_test = netB(netF(inputs_test))
                outputs_test = netC(features_test)
                # Cross-entropy against cluster pseudo-labels (no smoothing).
                classifier_loss = CrossEntropyLabelSmooth(num_classes=self.config["class_num"], epsilon=0)(outputs_test, pred)
                # Information maximisation: confident per-sample predictions...
                softmax_out = nn.Softmax(dim=1)(outputs_test)
                im_loss = torch.mean(Entropy(softmax_out))
                # ...with a diverse (high-entropy) marginal distribution.
                msoftmax = softmax_out.mean(dim=0)
                im_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
                total_loss = im_loss + 0.1 * classifier_loss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
            netF.eval()
            netB.eval()
            acc, _ = self.cal_acc(target_dataset, netF, netB, netC)
            # Loss values reported here are from the last batch of the epoch.
            log_str = 'tra-tag to: {}, Iter:{}/{}; Accuracy = {:.2f}%;t_ls: {:.4f}; mi_ls:{:.4f}: cl_loss: {:.4f}'\
                .format(self.config["target"], epoch + 1,
                        self.config["ft_epoch"], acc * 100, total_loss.item(), im_loss.item(), classifier_loss.item())
            print(log_str + '\n')
        # Restore gradient flow on the (copied) classifier parameters.
        for k, v in netC.named_parameters():
            v.requires_grad = True

    def build_global_center(self, loader, prev_F, prev_B, netC):
        """Compute class centers over the whole target loader (two clustering
        rounds, as in obtain_center) and print predict-vs-cluster accuracy."""
        netF = prev_F
        netC = netC
        netB = prev_B
        start_test = True
        with torch.no_grad():
            # Accumulate features, logits and labels over the entire loader.
            for i, data in enumerate(loader, 1):
                inputs = data['T']
                labels = data['T_label']
                inputs = inputs.cuda()
                feas = netB(netF(inputs))
                outputs = netC(feas)
                if start_test:
                    all_fea = feas.float().cpu()
                    all_output = outputs.float().cpu()
                    all_label = labels.float()
                    start_test = False
                else:
                    all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
                    all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                    all_label = torch.cat((all_label, labels.float()), 0)
        all_output = nn.Softmax(dim=1)(all_output)
        _, predict = torch.max(all_output, 1)
        accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
        # Bias-augment and L2-normalise features, then two clustering rounds.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
        all_fea = all_fea.float().cpu().numpy()
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc, 'cosine')
        pred_label = dd.argmin(axis=1)
        acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        center = torch.from_numpy(initc).cuda()
        log_str = 'predict acc to cluster acc = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
        print(log_str + '\n')
        return center

    def remove_global_center(self):
        """Drop the cached global centers (they are recomputed on next use)."""
        self.global_center = None

    def adapt_meta_learning_rate(self, loss):
        """Feed the validation loss to the plateau scheduler."""
        self.scheduler.step(loss)

    def get_meta_learning_rate(self):
        """Return the current outer-loop learning rate (first param group)."""
        epoch_learning_rate = []
        for param_group in self.meta_optim.param_groups:
            epoch_learning_rate.append(param_group['lr'])
        return epoch_learning_rate[0]

    def cal_acc(self, loader, netF, netB, netC):
        """Evaluate accuracy and mean prediction entropy over a loader."""
        start_test = True
        with torch.no_grad():
            for i, data in enumerate(loader, 1):
                inputs = data['T']
                labels = data['T_label']
                inputs = inputs.cuda()
                outputs = netC(netB(netF(inputs)))
                if start_test:
                    all_output = outputs.float().cpu()
                    all_label = labels.float()
                    start_test = False
                else:
                    all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                    all_label = torch.cat((all_label, labels.float()), 0)
        _, predict = torch.max(all_output, 1)
        accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
        mean_ent = torch.mean(Entropy(nn.Softmax(dim=1)(all_output))).cpu().data.item()
        return accuracy, mean_ent
if __name__ == '__main__':
    # Module is import-only; no standalone entry point.
    pass
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable
import math
import pdb
import torch.nn.utils.weight_norm as weightNorm
from collections import OrderedDict
import torch.nn.functional as F
from models.layersFw import Conv2d_fw,BatchNorm2d_fw,Linear_fw,BatchNorm1d_fw
def init_weights(m):
    """Module initializer (use with ``module.apply``).

    Conv / ConvTranspose: Kaiming-uniform weight, zero bias (if present).
    BatchNorm: N(1.0, 0.02) weight, zero bias (if present).
    Linear: Xavier-normal weight, zero bias (if present).
    """
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        nn.init.kaiming_uniform_(m.weight)
        # BUGFIX: guard against bias-free layers (bias=False), on which
        # zeros_(None) crashed.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
class feat_bootleneck(nn.Module):
    """Fast-weight bottleneck: Linear_fw projection + optional BN + dropout."""

    def __init__(self, feature_dim, bottleneck_dim=256, type="bn"):
        super(feat_bootleneck, self).__init__()
        # BUGFIX: the bn layer had been commented out while forward() still
        # calls self.bn whenever type == "bn" (the default), which raised
        # AttributeError on every forward pass. Restore the fast-weight BN.
        self.bn = BatchNorm1d_fw(bottleneck_dim, affin=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = Linear_fw(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type

    def forward(self, x):
        x = self.bottleneck(x)
        if self.type == "bn":
            x = self.bn(x)
        x = self.dropout(x)
        return x
class feat_classifier(nn.Module):
    """Two-layer MLP head (fast-weight linears) producing class logits."""

    def __init__(self, class_num=10, input_dim=256*4*4, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.hidden = 256
        self.class_num = class_num
        self.in_features = input_dim
        # bottleneck_dim -> hidden/2 -> class_num
        self.lin1 = Linear_fw(bottleneck_dim, self.hidden // 2)
        self.lin3 = Linear_fw(self.hidden // 2, self.class_num)
        for layer in (self.lin1, self.lin3):
            layer.apply(init_weights)
        self.relu = F.leaky_relu

    def forward(self, x):
        h = self.relu(self.lin1(x), negative_slope=0.1)
        h = F.dropout(h, 0.5)
        # Leaky-ReLU is also applied to the output logits (no softmax here).
        return self.relu(self.lin3(h), negative_slope=0.1)
class DTNBase(nn.Module):
    """Convolutional feature extractor (DTN backbone) with fast-weight convs.

    Three stride-2 5x5 convolutions (3->64->128->256), each followed by
    Dropout2d(0.5) and ReLU; the feature map is flattened in forward().
    ``in_features`` = 256*4*4, which presumably assumes a 32x32 input —
    TODO confirm against the data loader.
    """
    def __init__(self):
        super(DTNBase, self).__init__()
        stages = [
            Conv2d_fw(3, 64, kernel_size=5, stride=2, padding=2),
            nn.Dropout2d(0.5),
            nn.ReLU(),
            Conv2d_fw(64, 128, kernel_size=5, stride=2, padding=2),
            nn.Dropout2d(0.5),
            nn.ReLU(),
            Conv2d_fw(128, 256, kernel_size=5, stride=2, padding=2),
            nn.Dropout2d(0.5),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*stages)
        self.in_features = 256 * 4 * 4
    def forward(self, x):
        features = self.conv_params(x)
        return features.view(features.size(0), -1)
class LeNetBase(nn.Module):
    """LeNet-style feature extractor using fast-weight convolutions.

    conv(3->20, 5x5) -> maxpool -> ReLU -> conv(20->50, 5x5) -> dropout
    -> maxpool -> ReLU, then flattened. ``in_features`` is hard-coded to
    1250 (= 50*5*5), which presumably assumes a 32x32 RGB input — TODO
    confirm (the commented-out 50*4*4 would match 28x28).
    """
    def __init__(self):
        super(LeNetBase, self).__init__()
        stages = [
            Conv2d_fw(3, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            Conv2d_fw(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*stages)
        self.in_features = 1250
    def forward(self, x):
        feats = self.conv_params(x)
        return feats.view(feats.size(0), -1)
import torchvision
from torchvision import models
import math
import pdb
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Sigmoid ramp-up coefficient used for GRL scheduling.

    Ramps smoothly from ``low`` (at iter_num=0 the expression evaluates to
    exactly ``low``) towards ``high`` as iter_num approaches max_iter.

    BUG FIX: ``np.float`` was an alias of the builtin ``float`` and was
    removed in NumPy 1.24, so this line crashed on modern NumPy; use the
    builtin directly.
    """
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter)) - (high - low) + low)
def init_weights(m):
    """Layer-family weight initializer, for use with ``module.apply``.

    Conv / ConvTranspose: Kaiming-uniform weight, zero bias.
    BatchNorm: weight ~ N(1.0, 0.02), zero bias.
    Linear: Xavier-normal weight, zero bias.
    """
    cls = type(m).__name__
    if cls.find('Conv2d') >= 0 or cls.find('ConvTranspose2d') >= 0:
        nn.init.kaiming_uniform_(m.weight)
        nn.init.zeros_(m.bias)
    elif cls.find('BatchNorm') >= 0:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)
    elif cls.find('Linear') >= 0:
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
# Map each supported architecture name to its torchvision constructor;
# consumed by ResNetFc below to build an ImageNet-pretrained backbone.
resnet_dict = {"ResNet18": models.resnet18, "ResNet34": models.resnet34, "ResNet50": models.resnet50,
               "ResNet101": models.resnet101, "ResNet152": models.resnet152}
def grl_hook(coeff):
    """Return a tensor hook that reverses and scales gradients by ``coeff``.

    Registering the returned function via ``tensor.register_hook`` turns the
    tensor into a gradient-reversal layer (adversarial training).
    """
    def _reverse(grad):
        # Flip the sign so upstream layers receive adversarial gradients.
        return -coeff * grad.clone()
    return _reverse
class ResNetFc(nn.Module):
    """Torchvision ResNet backbone (ImageNet-pretrained) with an optional
    bottleneck projection, for domain-adaptation feature extraction.

    forward() returns features only; the classifier ``fc`` is constructed
    but its application is commented out below.
    """
    def __init__(self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
        """
        Arguments:
            resnet_name (str): key into ``resnet_dict`` (e.g. "ResNet50").
            use_bottleneck (bool): project pooled features to bottleneck_dim.
            bottleneck_dim (int): size of the bottleneck projection.
            new_cls (bool): build a freshly-initialized classifier instead of
                reusing the pretrained 1000-way fc.
            class_num (int): number of classes for the new classifier.
        """
        super(ResNetFc, self).__init__()
        # Downloads pretrained weights on first use.
        model_resnet = resnet_dict[resnet_name](pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \
                                            self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool)
        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        if new_cls:
            if self.use_bottleneck:
                self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim)
                self.fc = nn.Linear(bottleneck_dim, class_num)
                self.bottleneck.apply(init_weights)
                self.fc.apply(init_weights)
                # double underscore -> name-mangled to _ResNetFc__in_features
                self.__in_features = bottleneck_dim
            else:
                self.fc = nn.Linear(model_resnet.fc.in_features, class_num)
                self.fc.apply(init_weights)
                self.__in_features = model_resnet.fc.in_features
        else:
            # Keep the pretrained 1000-way classifier.
            self.fc = model_resnet.fc
            self.__in_features = model_resnet.fc.in_features
    def forward(self, x):
        """Return pooled (and optionally bottlenecked) features; no logits."""
        x = self.feature_layers(x)
        x = x.view(x.size(0), -1)
        if self.use_bottleneck and self.new_cls:
            x = self.bottleneck(x)
        # y = self.fc(x)
        return x
    def output_num(self):
        """Dimensionality of the features returned by forward()."""
        return self.__in_features
    def get_parameters(self):
        """Optimizer parameter groups; new layers get a 10x lr multiplier."""
        if self.new_cls:
            if self.use_bottleneck:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
                                  {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
            else:
                parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
                                  {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
        else:
            parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
        return parameter_list
class RandomLayer(nn.Module):
def __init__(self, input_dim_list=[], output_dim=1024):
super(RandomLayer, self).__init__()
self.input_num = len(input_dim_list)
self.output_dim = output_dim
self.random_matrix = [torch.randn(input_dim_list[i], output_dim) for i in range(self.input_num)]
def forward(self, input_list):
return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
for single in return_list[1:]:
return_tensor = torch.mul(return_tensor, single)
return return_tensor
def cuda(self):
super(RandomLayer, self).cuda()
self.random_matrix = [val.cuda() for val in self.random_matrix]
# class LRN(nn.Module):
# def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
# super(LRN, self).__init__()
# self.ACROSS_CHANNELS = ACROSS_CHANNELS
# if ACROSS_CHANNELS:
# self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
# stride=1,
# padding=(int((local_size - 1.0) / 2), 0, 0))
# else:
# self.average = nn.AvgPool2d(kernel_size=local_size,
# stride=1,
# padding=int((local_size - 1.0) / 2))
# self.alpha = alpha
# self.beta = beta
#
# def forward(self, x):
# if self.ACROSS_CHANNELS:
# div = x.pow(2).unsqueeze(1)
# div = self.average(div).squeeze(1)
# div = div.mul(self.alpha).add(1.0).pow(self.beta)
# else:
# div = x.pow(2)
# div = self.average(div)
# div = div.mul(self.alpha).add(1.0).pow(self.beta)
# x = x.div(div)
# return x
#
#
# class AlexNet(nn.Module):
#
# def __init__(self, num_classes=1000):
# super(AlexNet, self).__init__()
# self.features = nn.Sequential(
# nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
# nn.ReLU(inplace=True),
# LRN(local_size=5, alpha=0.0001, beta=0.75),
# nn.MaxPool2d(kernel_size=3, stride=2),
# nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
# nn.ReLU(inplace=True),
# LRN(local_size=5, alpha=0.0001, beta=0.75),
# nn.MaxPool2d(kernel_size=3, stride=2),
# nn.Conv2d(256, 384, kernel_size=3, padding=1),
# nn.ReLU(inplace=True),
# nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
# nn.ReLU(inplace=True),
# nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
# nn.ReLU(inplace=True),
# nn.MaxPool2d(kernel_size=3, stride=2),
# )
# self.classifier = nn.Sequential(
# nn.Linear(256 * 6 * 6, 4096),
# nn.ReLU(inplace=True),
# nn.Dropout(),
# nn.Linear(4096, 4096),
# nn.ReLU(inplace=True),
# nn.Dropout(),
# nn.Linear(4096, num_classes),
# )
#
# def forward(self, x):
# x = self.features(x)
# print(x.size())
# x = x.view(x.size(0), 256 * 6 * 6)
# x = self.classifier(x)
# return x
#
#
# def alexnet(pretrained=False, **kwargs):
# r"""AlexNet model architecture from the
# `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
# Args:
# pretrained (bool): If True, returns a model pre-trained on ImageNet
# """
# model = AlexNet(**kwargs)
# if pretrained:
# model_path = './alexnet.pth.tar'
# pretrained_model = torch.load(model_path)
# model.load_state_dict(pretrained_model['state_dict'])
# return model
#
#
# # convnet without the last layer
# class AlexNetFc(nn.Module):
# def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
# super(AlexNetFc, self).__init__()
# model_alexnet = alexnet(pretrained=True)
# self.features = model_alexnet.features
# self.classifier = nn.Sequential()
# for i in range(6):
# self.classifier.add_module("classifier" + str(i), model_alexnet.classifier[i])
# self.feature_layers = nn.Sequential(self.features, self.classifier)
#
# self.use_bottleneck = use_bottleneck
# self.new_cls = new_cls
# if new_cls:
# if self.use_bottleneck:
# self.bottleneck = nn.Linear(4096, bottleneck_dim)
# self.fc = nn.Linear(bottleneck_dim, class_num)
# self.bottleneck.apply(init_weights)
# self.fc.apply(init_weights)
# self.__in_features = bottleneck_dim
# else:
# self.fc = nn.Linear(4096, class_num)
# self.fc.apply(init_weights)
# self.__in_features = 4096
# else:
# self.fc = model_alexnet.classifier[6]
# self.__in_features = 4096
#
# def forward(self, x):
# x = self.features(x)
# x = x.view(x.size(0), -1)
# x = self.classifier(x)
# if self.use_bottleneck and self.new_cls:
# x = self.bottleneck(x)
# y = self.fc(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
# def get_parameters(self):
# if self.new_cls:
# if self.use_bottleneck:
# parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
# return parameter_list
#
#
#
# vgg_dict = {"VGG11": models.vgg11, "VGG13": models.vgg13, "VGG16": models.vgg16, "VGG19": models.vgg19,
# "VGG11BN": models.vgg11_bn, "VGG13BN": models.vgg13_bn, "VGG16BN": models.vgg16_bn,
# "VGG19BN": models.vgg19_bn}
#
#
# class VGGFc(nn.Module):
# def __init__(self, vgg_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
# super(VGGFc, self).__init__()
# model_vgg = vgg_dict[vgg_name](pretrained=True)
# self.features = model_vgg.features
# self.classifier = nn.Sequential()
# for i in range(6):
# self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
# self.feature_layers = nn.Sequential(self.features, self.classifier)
#
# self.use_bottleneck = use_bottleneck
# self.new_cls = new_cls
# if new_cls:
# if self.use_bottleneck:
# self.bottleneck = nn.Linear(4096, bottleneck_dim)
# self.fc = nn.Linear(bottleneck_dim, class_num)
# self.bottleneck.apply(init_weights)
# self.fc.apply(init_weights)
# self.__in_features = bottleneck_dim
# else:
# self.fc = nn.Linear(4096, class_num)
# self.fc.apply(init_weights)
# self.__in_features = 4096
# else:
# self.fc = model_vgg.classifier[6]
# self.__in_features = 4096
#
# def forward(self, x):
# x = self.features(x)
# x = x.view(x.size(0), -1)
# x = self.classifier(x)
# if self.use_bottleneck and self.new_cls:
# x = self.bottleneck(x)
# y = self.fc(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
# def get_parameters(self):
# if self.new_cls:
# if self.use_bottleneck:
# parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
# {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
# else:
# parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
# return parameter_list
#
#
# # For SVHN dataset
# class DTN(nn.Module):
# def __init__(self):
# super(DTN, self).__init__()
# self.conv_params = nn.Sequential(
# nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(64),
# nn.Dropout2d(0.1),
# nn.ReLU(),
# nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(128),
# nn.Dropout2d(0.3),
# nn.ReLU(),
# nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
# nn.BatchNorm2d(256),
# nn.Dropout2d(0.5),
# nn.ReLU()
# )
#
# self.fc_params = nn.Sequential(
# nn.Linear(256 * 4 * 4, 512),
# nn.BatchNorm1d(512),
# nn.ReLU(),
# nn.Dropout()
# )
#
# self.classifier = nn.Linear(512, 10)
# self.__in_features = 512
#
# def forward(self, x):
# x = self.conv_params(x)
# x = x.view(x.size(0), -1)
# x = self.fc_params(x)
# y = self.classifier(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
#
#
# class LeNet(nn.Module):
# def __init__(self):
# super(LeNet, self).__init__()
# self.conv_params = nn.Sequential(
# nn.Conv2d(1, 20, kernel_size=5),
# nn.MaxPool2d(2),
# nn.ReLU(),
# nn.Conv2d(20, 50, kernel_size=5),
# nn.Dropout2d(p=0.5),
# nn.MaxPool2d(2),
# nn.ReLU(),
# )
#
# self.fc_params = nn.Sequential(nn.Linear(50 * 4 * 4, 500), nn.ReLU(), nn.Dropout(p=0.5))
# self.classifier = nn.Linear(500, 10)
# self.__in_features = 500
#
# def forward(self, x):
# x = self.conv_params(x)
# x = x.view(x.size(0), -1)
# x = self.fc_params(x)
# y = self.classifier(x)
# return x, y
#
# def output_num(self):
# return self.__in_features
# class AdversarialNetwork(nn.Module):
# def __init__(self, in_feature, hidden_size):
# super(AdversarialNetwork, self).__init__()
# self.ad_layer1 = nn.Linear(in_feature, hidden_size)
# self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
# self.ad_layer3 = nn.Linear(hidden_size, 1)
# self.relu1 = nn.ReLU()
# self.relu2 = nn.ReLU()
# self.dropout1 = nn.Dropout(0.5)
# self.dropout2 = nn.Dropout(0.5)
# self.sigmoid = nn.Sigmoid()
# self.apply(init_weights)
# self.iter_num = 0
# self.alpha = 10
# self.low = 0.0
# self.high = 1.0
# self.max_iter = 10000.0
#
# def forward(self, x):
# if self.training:
# self.iter_num += 1
# coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
# x = x * 1.0
# x.register_hook(grl_hook(coeff))
# x = self.ad_layer1(x)
# x = self.relu1(x)
# x = self.dropout1(x)
# x = self.ad_layer2(x)
# x = self.relu2(x)
# x = self.dropout2(x)
# y = self.ad_layer3(x)
# y = self.sigmoid(y)
# return y
#
# def output_num(self):
# return 1
#
# def get_parameters(self):
# return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}]
"""Modified PyTorch Optimizer for retaining computation graphs.
Author: Sebastian Flennerhag
The general rule for porting a PyTorch optimizer is
1. Inherit the PyTorch optimizer
2. Override the ``step`` method by
1. Add a ``retain_graph`` argument
3. Rewrite step to
1. Remove any in-place operations
2. use a clone of the gradient for scaling factors that involve
the gradient
3. Return a list of all parameters
Expected behavior: if ``retain_graph=False``, revert to default behavior.
Else, use the overridden method.
The overridden method will create a computational graph and return this
in the ``new_parameters`` list created through (c). Note that the
*optimizer* still keeps the original node, so after taking a ``step``,
it is necessary to replace the ``_parameters`` dict underlying
the model (assuming it inherits ``nn.Module``).
"""
import math
import torch
from torch import optim
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=redefined-builtin
# pylint: disable=protected-access
# pylint: disable=too-many-locals
# pylint: disable=arguments-differ
# pylint: disable=missing-docstring
# pylint: disable=too-few-public-methods
class SGD(optim.SGD):
    """SGD whose ``step`` can retain the computation graph.

    With ``retain_graph=True`` the update is computed out-of-place and the
    new parameters (which carry the update graph) are returned instead of
    mutating the originals, enabling backprop through the optimizer.

    Arguments (extra over optim.SGD):
        detach: if True, treat gradients as constants (drops second-order
            terms from the returned graph).
    """
    def __init__(self, *args, detach=False, **kwargs):
        self.detach = detach
        super(SGD, self).__init__(*args, **kwargs)
    def step(self, closure=None, retain_graph=False):
        """Performs an optimization step and retain the computational graph.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            retain_graph (bool): whether to allow backprop through optimizer.
        Returns:
            (loss, new_params) when retain_graph=True.
        """
        if not retain_graph:
            return super(SGD, self).step(closure)
        loss = None
        if closure is not None:
            loss = closure()
        new_params = []
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            new_pg = []
            for p in group['params']:
                if p.grad is None:
                    new_params.append(p)
                    # BUG FIX: also keep grad-less params in the group;
                    # previously they were silently dropped from
                    # group['params'] (the Adam variant keeps them).
                    new_pg.append(p)
                    continue
                d_p = p.grad if not self.detach else p.grad.detach()
                if weight_decay != 0:
                    # BUG FIX: Tensor.add(Number, Tensor) overload was removed
                    # from PyTorch; use the (other, alpha=...) form.
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = \
                            torch.zeros_like(p.data)
                        buf = buf.mul(momentum).add(d_p)
                    else:
                        buf = param_state['momentum_buffer'].to(p.device)
                        # BUG FIX: same removed overload as above.
                        buf = buf.mul(momentum).add(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p + momentum * buf
                    else:
                        d_p = buf
                # Out-of-place update so the graph survives.
                p = p - group['lr'] * d_p
                p.retain_grad()
                new_params.append(p)
                new_pg.append(p)
            group['params'] = new_pg
        return loss, new_params
class Adam(optim.Adam):
    """Adam whose ``step`` can retain the computation graph.

    With ``retain_graph=True`` the update is computed out-of-place and the
    new parameters (carrying the update graph) are returned, enabling
    backprop through the optimizer.

    Arguments (extra over optim.Adam):
        detach: treat gradients as constants (no second-order terms).
        detach_first_moment: detach the gradient used for the first moment.
        detach_second_moment: detach the gradient used for the second moment.
    """
    def __init__(self, *args, detach=False,
                 detach_first_moment=False,
                 detach_second_moment=False, **kwargs):
        self.detach = detach
        self.detach_first_moment = detach_first_moment
        self.detach_second_moment = detach_second_moment
        super(Adam, self).__init__(*args, **kwargs)
    def step(self, closure=None, retain_graph=False):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            retain_graph (bool): whether to allow backprop through optimizer.
        Returns:
            (loss, new_params) when retain_graph=True.
        """
        if not retain_graph:
            return super(Adam, self).step(closure)
        loss = None
        if closure is not None:
            loss = closure()
        new_params = []
        for group in self.param_groups:
            new_pg = []
            for p in group['params']:
                if p.grad is None:
                    new_params.append(p)
                    new_pg.append(p)
                    continue
                grad = p.grad if not self.detach else p.grad.detach()
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients.')
                amsgrad = group['amsgrad']
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad.
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                if group['weight_decay'] != 0:
                    # BUG FIX: Tensor.add(Number, Tensor) overload was removed
                    # from PyTorch; use the (other, alpha=...) form.
                    grad = grad.add(p, alpha=group['weight_decay'])
                # Decay the first and second moment running average coefficient
                # (out-of-place so the graph is kept).
                # NOTE(review): the updated moments are never written back to
                # ``state``, so they restart from the stored values on each
                # call — confirm this is intended.
                g1 = grad.detach() if self.detach_first_moment else grad
                g2 = grad.detach() if self.detach_second_moment else grad
                # BUG FIX: scalar-first add/addcmul overloads removed; use
                # alpha=/value= keywords.
                exp_avg = exp_avg.mul(beta1).add(g1, alpha=1 - beta1)
                exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(g2, g2, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg.
                    max_exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2)
                step_size /= bias_correction1
                # BUG FIX: addcdiv(Number, Tensor, Tensor) overload removed.
                p = p.addcdiv(exp_avg, denom, value=-step_size)
                p.retain_grad()
                new_params.append(p)
                new_pg.append(p)
            group['params'] = new_pg
        return loss, new_params
"""Meta learner objectives for WarpGrad.
Updaters are classes that manage how meta-updates should be performed. The
main `DualUpdater` allows for dual updates to warp parameters and the
initialization.
:author: Sebastian Flennerhag.
"""
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=redefined-builtin
# pylint: disable=protected-access
# pylint: disable=too-many-locals
# pylint: disable=arguments-differ
import numpy as np
import joblib
import torch
from .optim import SGD
from models.utils_warp import (step, approx_step, unfreeze, freeze,
get_data, acc_fn, global_norm, backward,
state_dict_to_par_list)
class DualUpdater:
    """Implements the WarpGrad meta-objective.
    This updater applies the WarpGrad meta-objective to warp-parameters and
    if specified, to the initialization using the given meta-loss for the
    initialization.
    """
    def __init__(self, criterion, warp_objective=1, init_objective=1,
                 epochs=1, bsz=1, norm=True, approx=False):
        """Initialize an updater.
        Arguments:
            criterion (function): task loss criterion.
            warp_objective (int): type of WarpGrad objective
                (key into WARP_OBJECTIVES; 0 disables).
            init_objective (int): type of objective for initialization
                (key into INIT_OBJECTIVES; 0 disables).
            epochs (int): number of times to iterate over buffer (default=1).
            bsz (int): task parameter batch size between updates (default=1).
            norm (bool): use the norm in the Leap objective (d1)
                (default=True).
            approx (bool): use approximate (Hessian-free) meta-objective.
        """
        self.warp_objective = warp_objective
        self.init_objective = init_objective
        self.criterion = criterion
        self.epochs = epochs
        self.approx = approx
        self.norm = norm
        self.bsz = bsz
    def backward(self, model, step_fn, **opt_kwargs):
        """Compute meta gradients wrt dec and code
        Arguments:
            model (Warp): warped model to backprop through.
            step_fn (function): step function for the meta gradient.
            **opt_kwargs (kwargs): optional arguments to inner optimizer.
        """
        # The replay buffer yields (data, params) or
        # (data, params, optimizer_buffers).
        out = model.buffer.dataset
        if len(out) == 2:
            optimizer_buffers = None
            data, params = out
        else:
            data, params, optimizer_buffers = out
        # Warp-parameter update first, then (optionally) the initialization;
        # both calls mutate gradients/parameters via step_fn as a side effect.
        warp_objective = WARP_OBJECTIVES[self.warp_objective]
        warp_objective(model, self.criterion, params, optimizer_buffers, data,
                       step_fn, opt_kwargs, self.epochs, self.bsz, self.approx)
        init_objective= INIT_OBJECTIVES[self.init_objective]
        init_objective(model.named_init_parameters(suffix=None),
                       params, self.norm, self.bsz, step_fn)
def warp_on_same_loss(model, criterion, trj, brj, tds, step_fn,
                      opt_kwargs, epochs, bsz, approx):
    """WarpGrad uses same objective in first and second step.

    Replays buffered task trajectories: for each stored parameter state,
    restores the model to that state, takes (or approximates) one inner
    SGD step, and accumulates the post-step loss; the accumulated loss is
    backpropagated into warp parameters only, then step_fn applies it.

    Arguments:
        model: warped model (set_state / meta_parameters / adapt_parameters).
        criterion: task loss function.
        trj: per-task list of buffered parameter states.
        brj: per-task list of optimizer buffers (falsy if none stored).
        tds: per-task data iterators.
        step_fn: callback applying the accumulated meta-gradient.
        opt_kwargs: kwargs for the differentiable inner SGD.
        epochs: number of passes over the buffer.
        bsz: buffered states per meta-update; <= 0 processes all at once
            (NOTE(review): the loss is still divided by bsz in that case —
            verify bsz <= 0 is actually used).
        approx: if True, use the Hessian-free approximation (no inner step).
    """
    unfreeze(model.meta_parameters(include_init=False))
    unfreeze(model.adapt_parameters())
    model.train()
    def _get(t, i):
        # Pair buffered state i of task t with two fresh batches from its
        # data iterator (inner batch + outer/meta batch).
        state = trj[t][i]
        buffer = brj[t][i] if brj else None
        # pylint: disable=unbalanced-tuple-unpacking
        (x, y), (x2, y2) = get_data(tds[t], 2)
        return x, y, state, buffer, x2, y2
    def _step(batch):
        # Accumulate the meta-loss over a batch of buffered states, then
        # backprop into warp parameters and apply step_fn.
        loss = 0
        for (x, y, state, buffer, x2, y2) in batch:
            model.set_state(state)
            opt = SGD(model.optimizer_parameter_groups(tensor=True),
                      **opt_kwargs)
            opt.zero_grad()
            if buffer:
                # Restore per-parameter optimizer buffers (e.g. momentum).
                for p, b in zip(
                        model.optimizer_parameter_groups(tensor=True), buffer):
                    opt.state[p] = b
            if approx:
                l1 = a1 = None
                l2, a2 = approx_step(x, y, model, criterion, acc_fn)
            else:
                l2, (l1, a1, a2) = step(x, y, x2, y2, model,
                                        opt, criterion, acc_fn)
            del l1, a1, a2  # unused for now.
            loss = loss + l2
        loss = loss / bsz
        backward(loss, model.meta_parameters(include_init=False))
        step_fn()
    for _ in range(epochs):
        # Flatten (task, index) pairs and visit them in random order.
        datapoints = [_get(t, i) for t in trj for i in range(len(trj[t]))]
        np.random.shuffle(datapoints)
        if bsz > 0:
            for i in range(0, len(datapoints), bsz):
                _step(datapoints[i:i+bsz])
        else:
            _step(datapoints)
    freeze(model.meta_parameters(include_init=False))
def simplified_leap(named_init, trj, norm, bsz, step_fn):
    """One step of Leap over trajectories, wrt shared init.
    Similar to Leap objective except the loss delta is omitted: each init
    parameter's gradient is set to the sum of (optionally normalized)
    per-segment parameter displacements across all buffered task
    trajectories, averaged over the number of tasks.

    Arguments:
        named_init: iterable of (name, parameter) pairs for the shared init.
        trj: per-task list of parameter state dicts, in visit order.
        norm: if True, normalize each segment by its global norm.
        bsz: unused.
        step_fn: callback that applies the accumulated gradients.
    """
    del bsz  # unused
    # TODO: allow choice of cpu or gpu
    par_names, init = zip(*named_init)
    device = init[0].device
    unfreeze(init)
    # Compute each trajectory segment's displacement in parallel (threads,
    # so tensors are shared without pickling).
    with joblib.Parallel(n_jobs=-1, backend='threading') as parallel:
        adds = parallel(
            joblib.delayed(line_seg_len)(
                trj[t][i], trj[t][i + 1], par_names, norm, device)
            for t in trj
            for i in range(0, len(trj[t])-1)
        )
    # Sum segment displacements per parameter; write directly into .grad
    # and average over tasks so step_fn performs the meta-update.
    for i, a in zip(init, zip(*adds)):
        a = torch.stack(a, dim=0)  # pylint: disable=no-member
        i.grad = a.data.sum(dim=0)
        i.grad.div_(len(trj))
    step_fn()
    freeze(init)
def line_seg_len(entry_state, exit_state, par_names, norm, device):
    """Per-parameter displacement (entry - exit) for one trajectory segment.

    If ``norm`` is truthy the whole segment is divided in place by its
    global L2 norm. Returns a list of tensors on ``device``.
    """
    entry_params = state_dict_to_par_list(entry_state, par_names)
    exit_params = state_dict_to_par_list(exit_state, par_names)
    deltas = [a.data.to(device) - b.data.to(device)
              for a, b in zip(entry_params, exit_params)]
    if norm:
        seg_norm = global_norm(deltas, detach=True, eps=1e-9)
        for delta in deltas:
            delta.data.div_(seg_norm)
    return deltas
def null_func(*args, **kwargs):
    """No-op stand-in for a disabled objective; accepts and ignores anything."""
    del args, kwargs  # unused.
    return None
# Dispatch tables mapping integer objective ids to implementations;
# id 0 disables the corresponding update (see DualUpdater.__init__).
WARP_OBJECTIVES = {
    0: null_func,
    1: warp_on_same_loss,
}
INIT_OBJECTIVES = {
    0: null_func,
    1: simplified_leap,
}
"""Utilities for WarpGrad.
:author: Sebastian Flennerhag
"""
from collections import OrderedDict
import os
import torch
from models._dict import load_state_dict
def grad(x, y, model, params, criterion):
    """Compute gradients of the task loss wrt ``params``, keeping the graph.

    Arguments:
        x (torch.Tensor): input tensor.
        y (torch.Tensor): target tensor.
        model (Warp): warped model.
        params (list): parameters to differentiate the loss with respect to.
        criterion (fun): task loss criterion.
    """
    device = next(model.parameters()).device
    inputs = x.to(device, non_blocking=True)
    targets = y.to(device, non_blocking=True)
    predictions = model(inputs, cache_parameters=False)
    task_loss = criterion(predictions, targets)
    return torch.autograd.grad(task_loss, params, create_graph=True)
def approx_step(x_outer, y_outer, model, criterion, scorer):
    """Approximate meta gradient: evaluate the outer loss with no inner step.

    Arguments:
        x_outer (torch.Tensor): input to obtain meta learner gradient.
        y_outer (torch.Tensor): target to obtain meta learner gradient.
        model (Warp): warped model.
        criterion (fun): task loss criterion.
        scorer (fun): scoring function (optional).

    Returns:
        (outer_loss, outer_score) — outer_score is None without a scorer.
    """
    device = next(model.parameters()).device
    inputs = x_outer.to(device, non_blocking=True)
    targets = y_outer.to(device, non_blocking=True)
    predictions = model(inputs, cache_parameters=False)
    outer_loss = criterion(predictions, targets)
    if scorer is None:
        outer_score = None
    else:
        outer_score = scorer(predictions.detach(), targets.detach())
    return outer_loss, outer_score
def step(x_inner, y_inner, x_outer, y_outer,
         model, optimizer, criterion, scorer):
    """Compute gradients of meta-parameters using the WarpGrad objective.

    Takes one differentiable task step on (x_inner, y_inner), evaluates the
    outer loss on (x_outer, y_outer) through the updated parameters, then
    restores the original task parameters.

    Arguments:
        x_inner (torch.Tensor): input to obtain task learner gradient.
        y_inner (torch.Tensor): target to obtain task learner gradient.
        x_outer (torch.Tensor): input to obtain meta learner gradient.
        y_outer (torch.Tensor): target to obtain meta learner gradient.
        model (Warp): warped model.
        optimizer (warpgrad.optim.SGD, warpgrad.optim.Adam): task optimizer,
            must be differentiable.
        criterion (fun): task loss criterion.
        scorer (fun): scoring function (optional).

    Returns:
        (loss_outer, (detached loss_inner, score_inner, score_outer)).
    """
    device = next(model.parameters()).device
    original_tparams = OrderedDict(model.named_adapt_parameters())
    x_inner = x_inner.to(device, non_blocking=True)
    y_inner = y_inner.to(device, non_blocking=True)
    x_outer = x_outer.to(device, non_blocking=True)
    y_outer = y_outer.to(device, non_blocking=True)
    # Parameter update (graph kept so meta-gradients flow through it).
    pred_inner = model(x_inner, cache_parameters=False)
    # BUG FIX: removed leftover debug print ('setp') that dumped the full
    # prediction/target tensors to stdout on every inner step.
    loss_inner = criterion(pred_inner, y_inner)
    backward(loss_inner, model.adapt_parameters(), create_graph=True)
    _, new_params = optimizer.step(retain_graph=True)
    replace(model.model, new_params, original_tparams)
    # Get forward loss
    pred_outer = model(x_outer, cache_parameters=False)
    loss_outer = criterion(pred_outer, y_outer)
    # Reset model parameters
    load_state_dict(model.model, original_tparams)
    score_inner = None
    score_outer = None
    if scorer is not None:
        score_inner = scorer(pred_inner.detach(), y_inner.detach())
        score_outer = scorer(pred_outer.detach(), y_outer.detach())
    return loss_outer, (loss_inner.detach(), score_inner, score_outer)
def replace(model, new_params, old_params):
    """Swap ``model``'s parameters for ``new_params`` without breaking backprop.

    ``new_params`` are paired with ``old_params``' names in order; stale
    gradients on the old parameters are cleared so they cannot be
    accumulated into later.
    """
    names = list(old_params.keys())
    assert len(names) == len(new_params)
    load_state_dict(model, OrderedDict(zip(names, new_params)))
    for old in old_params.values():
        old.grad = None  # drop current gradients to avoid accumulating into them
def backward(loss, args, create_graph=False):
    """Backpropagate ``loss`` into each tensor of ``args``.

    Computes the partial derivatives with torch.autograd.grad, then feeds
    each one back through ``tensor.backward`` so it accumulates into
    ``.grad`` (including any graph the tensor itself carries).
    """
    targets = list(args)
    partials = torch.autograd.grad(loss, targets, create_graph=create_graph)
    for tensor, partial in zip(targets, partials):
        tensor.backward(partial)
#############################################################################
def get_data(iterator, n_iterations):
    """Collect the first ``n_iterations`` batches from ``iterator``.

    Switches the underlying dataset to train mode first (via
    ``iterator.dataset.train()`` — presumably a mode toggle; confirm
    against the dataset class).
    """
    batches = []
    iterator.dataset.train()
    for count, batch in enumerate(iterator, start=1):
        batches.append(batch)
        if count == n_iterations:
            break
    return batches
def freeze(iterable):
    """Disable gradient tracking for every parameter in ``iterable``."""
    for param in iterable:
        param.requires_grad = False
def unfreeze(iterable):
    """Enable gradient tracking for every parameter in ``iterable``.

    BUG FIX: the docstring previously said "Freeze" (copy-paste from the
    sibling helper) although the function enables gradients.
    """
    for p in iterable:
        p.requires_grad = True
def get_groups(parameters, opt_params, tensor=True):
"""Return layer-wise optimization hyper-parameters."""
groups = []
for p, lr, mom in zip(parameters, opt_params.lr, opt_params.momentum):
if not tensor:
lr = lr.item()
mom = mom.item()
groups.append({'params': p, 'lr': lr, 'momentum': mom})
return groups
#############################################################################
def stem(fpath):
    """Split a file's basename (before the first '.') on '_'.

    Returns [task-id, iter-id] for names of the form ``<task>_<iter>.<ext>``.
    """
    base = os.path.basename(fpath)
    name = str(base.split('.')[0])
    return name.split('_')
def load(path):
    """Load stored data in to task dict.

    Reads every file under ``path`` in sorted order (filenames encode
    ``<task>_<iter>``, so sorting preserves iteration order), groups the
    ``torch.load``-ed contents by task id, and returns
    {task-id: [objects in iter order]}.
    """
    files = sorted(os.listdir(path))  # sorting -> iter order
    mapped_files = {stem(f)[0]: [] for f in files}
    for fname in files:
        fpath = os.path.join(path, fname)
        n, i = stem(fname)
        # NOTE(review): asserts the number of distinct tasks equals each
        # file's iter index — presumably an invariant of how the buffer is
        # written; verify against the writer before relying on it.
        assert len(mapped_files) == int(i)
        mapped_files[n].append(torch.load(fpath))
    return mapped_files
def clear(path):
    """Delete every file directly under ``path``."""
    for fname in os.listdir(path):
        os.unlink(os.path.join(path, fname))
def state_dict_to_par_list(state_dict, par_names):
    """Return the tensors of ``state_dict`` whose names are in ``par_names``,
    preserving the state dict's iteration order."""
    selected = []
    for name, tensor in state_dict.items():
        if name in par_names:
            selected.append(tensor)
    return selected
def clone(tensor, device=None):
    """Detach-clone a tensor (or nested lists of tensors).

    Preserves ``requires_grad`` on each clone and optionally moves it to
    ``device``.

    BUG FIX: the recursive call for containers previously dropped the
    ``device`` argument, so tensors nested inside lists were never moved
    to the requested device.
    """
    if not isinstance(tensor, torch.Tensor):
        return [clone(t, device=device) for t in tensor]
    cloned = tensor.detach().clone()
    cloned.requires_grad = tensor.requires_grad
    if device is not None:
        cloned = cloned.to(device)
    return cloned
def clone_state(state_dict, *args, **kwargs):
    """Clone every tensor in a state dict (via ``clone``), keeping key order."""
    return OrderedDict(
        (name, clone(value, *args, **kwargs))
        for name, value in state_dict.items()
    )
def copy_opt(param_states):
    """Copy per-parameter optimizer buffers, moving each clone to CPU."""
    cloned_states = []
    for param_state in param_states:
        cloned_states.append(
            OrderedDict((key, value.clone().cpu())
                        for key, value in param_state.items()))
    return cloned_states
def copy(to_tensors, from_tensors):
    """In-place copy of tensor data between matching iterables.

    Accepts parallel lists/tuples, or dicts whose keys must match pairwise
    (a mismatch raises ValueError); any other container type also raises.
    """
    if isinstance(to_tensors, (list, tuple)):
        for dst, src in zip(to_tensors, from_tensors):
            dst.data.copy_(src.data)
    elif isinstance(to_tensors, (dict, OrderedDict)):
        for (dst_name, dst), (src_name, src) in zip(to_tensors.items(),
                                                    from_tensors.items()):
            if dst_name != src_name:
                raise ValueError(
                    'target state variable {}'
                    'does not match source state variable{}'.format(dst_name,
                                                                    src_name))
            dst.data.copy_(src.data)
    else:
        raise ValueError('Unknown iterables type {}'.format(type(to_tensors)))
def zero_grad(tensor_like):
    """Set tensor gradient to zero. Null op if argument is not a tensor.
    Argument:
        tensor_like: objects to zero grad for.
            If list, will iterate over elements.
    """
    if isinstance(tensor_like, (tuple, list)):
        for p in tensor_like:
            zero_grad(p)
        # Deliberately falls through: the container itself has no .grad
        # attribute, so the hasattr guard below returns for it.
    if not hasattr(tensor_like, 'grad'):
        return
    if tensor_like.grad is None:
        # Lazily allocate a grad buffer; 0-d tensors cannot use .new(*shape).
        if tensor_like.dim() == 0:
            tensor_like.grad = tensor_like.detach().clone()
        else:
            # .new() returns uninitialized memory — zeroed just below.
            tensor_like.grad = tensor_like.new(*tensor_like.shape)
    tensor_like.grad.zero_()
#############################################################################
def n_correct(logits, targets):
    """Count predictions (argmax over dim 1 of ``logits``) equal to targets.

    Args:
        logits (torch.Tensor): tensor of prediction logits.
        targets (torch.Tensor): tensor of class targets.
    """
    predictions = logits.max(1)[1]
    return (predictions == targets).sum().item()
def acc_fn(p, y):
    """Fraction of correct discrete predictions over the batch."""
    batch_size = y.size(0)
    return n_correct(p, y) / batch_size
def global_norm(tensors, detach=True, eps=1e-9):
    """L2 norm over the concatenation of all elements of all tensors.

    Returns sqrt(sum of squares) + ``eps`` (eps keeps later divisions safe).
    """
    total = 0.
    for t in tensors:
        flat = t.view(-1)
        if detach:
            flat = flat.detach().data
        total += torch.dot(flat, flat)  # pylint: disable=no-member
    return total.sqrt() + eps
"""Runtime helpers"""
# pylint: disable=invalid-name,too-many-arguments,too-many-instance-attributes
import os
from os.path import join
import numpy as np
def convert_arg(arg):
    """Convert a string to None/bool/float/int, or return it unchanged.

    Strings containing a '.' are tried as float, others as int; on
    conversion failure the original string is returned.
    """
    # pylint: disable=broad-except
    lowered = arg.lower()
    if lowered == 'none':
        return None
    if lowered == 'false':
        return False
    if lowered == 'true':
        return True
    target = float if '.' in arg else int
    try:
        return target(arg)
    except Exception:
        return arg
def build_kwargs(args):
    """Build a kwargs dict from a flat list of key-value pairs.

    Arguments:
        args (list): flat ``[key1, value1, key2, value2, ...]`` list of
            strings. A value containing ':' becomes a tuple of converted
            elements.

    Returns:
        dict: mapping of ``str(key)`` to converted value.
    """
    kwargs = {}
    if not args:
        return kwargs
    assert len(args) % 2 == 0, "argument list %r does not appear to have key, value pairs" % args
    # BUG FIX: the original popped items off `args`, destructively emptying
    # the caller's list as a side effect. Iterate over pairs instead.
    for k, v in zip(args[::2], args[1::2]):
        if ':' in v:
            v = tuple(convert_arg(a) for a in v.split(':'))
        else:
            v = convert_arg(v)
        kwargs[str(k)] = v
    return kwargs
def compute_ncorrect(p, y):
    """Number of correct predictions over a tensor of logits."""
    predicted_classes = p.argmax(dim=1)
    return (predicted_classes == y).sum().item()
def compute_auc(x):
    """Compute AUC (composite trapezoidal rule).

    Each interior step contributes the average of its two endpoints,
    normalized by the sequence length.
    """
    T = len(x)
    return sum((x[i - 1] + x[i]) / (2 * T) for i in range(1, T))
def unlink(path):
    """Delete all '.log' files directly under ``path``."""
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if full_path.endswith('.log'):
            os.unlink(full_path)
###############################################################################
def write(step, meta_loss, loss, accuracy, losses, accuracies, f):
    """Append one results record to file ``f``.

    Per-step losses/accuracies are serialized as ';'-separated floats.
    """
    loss_str = ''.join('{:f};'.format(item) for item in losses)
    acc_str = ''.join('{:f};'.format(item) for item in accuracies)
    record = "{:d},{:f},{:f},{:f},{:s},{:s}\n".format(
        step, meta_loss, loss, accuracy, loss_str, acc_str)
    with open(f, 'a') as handle:
        handle.write(record)
def log_status(results, idx, time):
    """Print a one-line training status summary."""
    #pylint: disable=unbalanced-tuple-unpacking,too-many-star-expressions
    msg = ("[{:9s}] time:{:3.3f} "
           "train: outer={:0.4f} inner={:0.4f} acc={:2.2f} ").format(
               str(idx),
               time,
               results.train_meta_loss,
               results.train_loss,
               results.train_acc)
    print(msg)
def write_train_res(results, step, log_dir):
    """Write results from a meta-train step to file."""
    out_file = join(log_dir, 'results_train_train.log')
    write(step,
          results.train_meta_loss,
          results.train_loss,
          results.train_acc,
          results.train_losses,
          results.train_accs,
          out_file)
def write_val_res(results, step, case, log_dir):
    """Write per-task results data (train and val splits) to file."""
    for task_id, res in enumerate(results):
        for split in ('train', 'val'):
            write(step,
                  getattr(res, '{}_meta_loss'.format(split)),
                  getattr(res, '{}_loss'.format(split)),
                  getattr(res, '{}_acc'.format(split)),
                  getattr(res, '{}_losses'.format(split)),
                  getattr(res, '{}_accs'.format(split)),
                  join(log_dir,
                       'results_{}_{}_{}.log'.format(task_id, case, split)))
###############################################################################
class Res:
    """Results container.

    Attributes:
        losses (list): list of losses over batch iterator
        accs (list): list of accs over batch iterator
        meta_loss (float): auc over losses
        loss (float): mean loss over losses. Call ``aggregate`` to compute.
        acc (float): mean acc over accs. Call ``aggregate`` to compute.
    """

    def __init__(self):
        self.losses, self.accs = [], []
        self.ncorrects, self.nsamples = [], []
        self.meta_loss = 0
        self.loss = 0
        self.acc = 0

    def log(self, loss, pred, target):
        """Record one batch: its loss and accuracy statistics."""
        batch_size = target.size(0)
        correct = compute_ncorrect(pred.data, target.data)
        self.losses.append(loss)
        self.ncorrects.append(correct)
        self.nsamples.append(batch_size)
        self.accs.append(correct / batch_size)

    def aggregate(self):
        """Compute aggregate statistics over all logged batches."""
        # Freeze the per-batch lists into arrays before reducing.
        for attr in ('accs', 'losses', 'nsamples', 'ncorrects'):
            setattr(self, attr, np.array(getattr(self, attr)))
        self.loss = self.losses.mean()
        self.meta_loss = compute_auc(self.losses)
        self.acc = self.ncorrects.sum() / self.nsamples.sum()
class AggRes:
    """Results aggregation container.

    Aggregates results over a mini-batch of tasks.
    """

    def __init__(self, results):
        self.train_res = results
        self.aggregate_train()

    def aggregate_train(self):
        """Aggregate train results."""
        (self.train_meta_loss,
         self.train_loss,
         self.train_acc,
         self.train_losses,
         self.train_accs) = self.aggregate(self.train_res)

    @staticmethod
    def aggregate(results):
        """Aggregate losses and accs across Res instances.

        Returns:
            tuple: (meta_loss, mean loss, mean acc,
                    per-step mean losses, per-step mean accs).
        """
        # Stack per-task series column-wise: shape (steps, tasks).
        losses = np.stack([res.losses for res in results], axis=1)
        ncorrects = np.stack([res.ncorrects for res in results], axis=1)
        nsamples = np.stack([res.nsamples for res in results], axis=1)
        per_step_losses = losses.mean(axis=1)
        per_step_accs = ncorrects.sum(axis=1) / nsamples.sum(axis=1)
        return (compute_auc(per_step_losses),
                losses.mean(),
                ncorrects.sum() / nsamples.sum(),
                per_step_losses,
                per_step_accs)
def consolidate(agg_res):
    """Merge a list of AggRes into one AggRes.

    NOTE(review): the original unconditionally read ``r.val_res``, but
    ``AggRes`` no longer sets that attribute (its validation path is
    commented out), so consolidate always raised AttributeError. This
    version includes val results only when present.
    """
    results = []
    for r in agg_res:
        merged = tuple(r.train_res)
        val_res = getattr(r, 'val_res', None)
        if val_res is not None:
            merged += tuple(val_res)
        results.append(merged)
    return AggRes(results)
"""Base Omniglot models. Based on original implementation:
https://github.com/amzn/metalearn-leap
"""
import torch.nn as nn
from torch import optim
from torchvision import models
import numpy as np
from models.wrapper import WarpGradWrapper
import maml
from models.ResNet import ResNetFc
# Number of classification heads built by ``Linear`` in multi-head mode.
NUM_CLASSES = 50
# Maps config-friendly names to activation-function classes for warp-layers;
# 'none' disables the non-linearity.
ACT_FUNS = {
    'none': None,
    'leakyrelu': nn.LeakyReLU,
    'relu': nn.ReLU,
    'sigmoid': nn.Sigmoid,
    'tanh': nn.Tanh
}
def get_model(args, criterion):
    """Construct model from main args.

    Builds the base network selected by ``args.meta_model`` and wraps it
    in the matching meta-learner wrapper.

    Arguments:
        args: main script arguments. Reads ``meta_model``, ``cuda``,
            ``log_ival`` and the inner/outer optimizer settings.
        criterion (func): loss criterion to use.

    Returns:
        WarpGradWrapper or MAMLWrapper: the wrapped model.

    Raises:
        NotImplementedError: if ``args.meta_model`` matches no known
            meta-learner. The original silently returned None in that
            case, or crashed with UnboundLocalError when ``args.cuda``
            was set, because ``model`` was never assigned.
    """
    # The original also built an unused `kwargs` dict for the commented-out
    # OmniConv model; it is dropped here as dead code.
    meta_model = args.meta_model.lower()
    model = None
    if "warp" in meta_model:
        model = WarpedResNet50()
    if "maml" in meta_model:
        # Mirrors the original precedence: a name containing both 'warp'
        # and 'maml' ends up with the ResNetFc base model.
        model = ResNetFc("ResNet50", use_bottleneck=True, bottleneck_dim=2560, new_cls=True, class_num=10)
    if model is None:
        raise NotImplementedError('Meta-learner {} unknown.'.format(meta_model))
    if args.cuda:
        model = model.cuda()
    if args.log_ival > 0:
        print(model)
    if "warp" in meta_model:
        return WarpGradWrapper(
            model,
            args.inner_opt,
            args.outer_opt,
            args.inner_kwargs,
            args.outer_kwargs,
            args.meta_kwargs,
            criterion)
    if meta_model == 'maml':
        return MAMLWrapper(
            model,
            args.inner_opt,
            args.outer_opt,
            args.inner_kwargs,
            args.outer_kwargs,
            criterion,
        )
    raise NotImplementedError('Meta-learner {} unknown.'.format(meta_model))
###############################################################################
class UnSqueeze(nn.Module):
    """Create channel dim if necessary."""

    def __init__(self):
        super(UnSqueeze, self).__init__()

    def forward(self, input):
        """Create a channel dimension on a 3-d tensor.

        Null-op if the input already is a 4-d tensor.

        Arguments:
            input (torch.Tensor): tensor to unsqueeze.
        """
        return input if input.dim() == 4 else input.unsqueeze(1)
class Squeeze(nn.Module):
    """Undo excess dimensions."""

    def __init__(self):  # pylint: disable=useless-super-delegation
        super(Squeeze, self).__init__()

    def forward(self, input):
        """Squeeze singular dimensions of an input tensor.

        Arguments:
            input (torch.Tensor): tensor to squeeze.
        """
        squeezed = input.squeeze()
        if input.size(0) != 0:
            return squeezed
        # Empty leading dimension: re-attach an explicit batch dim of one.
        return squeezed.view(1, *squeezed.size())
class Linear(nn.Module):
    """Wrapper around torch.nn.Linear to deal with single/multi-headed mode.

    Arguments:
        multi_head (bool): multi-headed mode.
        num_features_in (int): number of features in input.
        num_features_out (int): number of features in output.
        **kwargs: optional arguments to pass to torch.nn.Linear.
    """
    def __init__(self, multi_head, num_features_in,
                 num_features_out, **kwargs):
        super(Linear, self).__init__()
        self.num_features_in = num_features_in
        self.num_features_out = num_features_out
        self.multi_head = multi_head

        def _linear_factory():
            return nn.Linear(num_features_in, num_features_out, **kwargs)

        if self.multi_head:
            # BUG FIX: `[_linear_factory()] * NUM_CLASSES` replicated a
            # single nn.Linear object, so every head shared one weight
            # matrix and reset_parameters() reset the same layer N times.
            # Build one independent head per class instead.
            self.linear = nn.ModuleList(
                [_linear_factory() for _ in range(NUM_CLASSES)])
        else:
            self.linear = _linear_factory()

    def forward(self, x, idx=None):
        """Apply the (head-selected) linear projection.

        Arguments:
            x (torch.Tensor): input features, shape (N, num_features_in).
            idx (int): head index; required in multi-headed mode.
        """
        if self.multi_head:
            assert idx is not None, "Pass head idx in multi-headed mode."
            return self.linear[idx](x)
        return self.linear(x)

    def reset_parameters(self):
        """Reset parameters if in multi-headed mode."""
        if self.multi_head:
            for lin in self.linear:
                lin.reset_parameters()
        else:
            self.linear.reset_parameters()
###############################################################################
# class OmniConv(nn.Module):
# """ConvNet classifier.
# Arguments:
# num_classes (int): number of classes to predict in each alphabet
# num_layers (int): number of convolutional layers (default=4).
# kernel_size (int): kernel size in each convolution (default=3).
# num_filters (int): number of output filters in each convolution
# (default=64)
# imsize (tuple): tuple of image height and width dimension.
# padding (bool, int, tuple): padding argument to convolution layers
# (default=True).
# batch_norm (bool): use batch normalization in each convolution layer
# (default=True).
# multi_head (bool): multi-headed training (default=False).
# """
# def __init__(self, num_classes, num_layers=4, kernel_size=3,
# num_filters=64, imsize=(28, 28), padding=True,
# batch_norm=True, multi_head=False):
# super(OmniConv, self).__init__()
# self.num_layers = num_layers
# self.kernel_size = kernel_size
# self.num_filters = num_filters
# self.imsize = imsize
# self.batch_norm = batch_norm
# self.multi_head = multi_head
# def conv_block(nin):
# block = [nn.Conv2d(nin, num_filters, kernel_size, padding=padding),
# nn.MaxPool2d(2)]
# if batch_norm:
# block.append(nn.BatchNorm2d(num_filters))
# block.append(nn.ReLU())
# return block
# layers = [UnSqueeze()]
# for i in range(num_layers):
# layers.extend(conv_block(1 if i == 0 else num_filters))
# layers.append(Squeeze())
# self.conv = nn.Sequential(*layers)
# self.head = Linear(self.multi_head, num_filters, num_classes)
# def forward(self, input, idx=None):
# input = self.conv(input)
# return self.head(input, idx)
# def init_adaptation(self):
# """Reset stats for new task"""
# # Reset if multi-head, otherwise null-op
# self.head.reset_parameters()
# # Reset BN running stats
# for m in self.modules():
# if hasattr(m, 'reset_running_stats'):
# m.reset_running_stats()
###############################################################################
class WarpLayer(nn.Module):
    """Warp-layer module.

    Allows flexible configuration of convolutional warp-layers.

    Arguments:
        num_features_in (int): number of input filters.
        num_features_out (int): number of output filters.
        kernel_size (int): kernel size in each convolution (default=3).
        padding (bool, int, tuple): padding argument to convolution layer.
        residual_connection (bool): add residual connection.
        batch_norm (bool): use batch normalization in warp-layer.
        act_fun (class): non-linearity class instantiated in-place (optional).
    """

    def __init__(self, num_features_in, num_features_out,
                 kernel_size, padding, residual_connection,
                 batch_norm, act_fun):
        super(WarpLayer, self).__init__()
        self.residual_connection = residual_connection
        self.bn_in = nn.BatchNorm2d(num_features_in) if batch_norm else None
        self.bn_out = None
        if batch_norm and self.residual_connection:
            self.bn_out = nn.BatchNorm2d(num_features_out)
        self.conv = nn.Conv2d(num_features_in,
                              num_features_out,
                              kernel_size,
                              padding=padding)
        self.act_fun = None if act_fun is None else act_fun()
        # 1x1 projection so the skip connection matches channel counts.
        needs_scaling = (residual_connection
                         and num_features_in != num_features_out)
        self.scale = (nn.Conv2d(num_features_in, num_features_out, 1)
                      if needs_scaling else None)

    def forward(self, x):
        out = x if self.bn_in is None else self.bn_in(x)
        out = self.conv(out)
        if self.act_fun is not None:
            out = self.act_fun(out)
        if self.residual_connection:
            skip = x if self.scale is None else self.scale(x)
            out = skip + out
        if self.bn_out is not None:
            out = self.bn_out(out)
        return out
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Sigmoid ramp-up coefficient from ``low`` to ``high``.

    Computes 2*(high-low)/(1+exp(-alpha*iter_num/max_iter)) - (high-low) + low,
    which is ``low`` at iter_num=0 and approaches ``high`` as iter_num grows.

    BUG FIX: the original cast with ``np.float``, an alias deprecated in
    NumPy 1.20 and removed in 1.24; the builtin ``float`` is the drop-in
    replacement.
    """
    return float(2.0 * (high - low) / (1.0 + np.exp(-alpha * iter_num / max_iter))
                 - (high - low) + low)
class WarpedResNet50(nn.Module):
    """ResNet-backed classifier with warp-layers interleaved between stages.

    Despite the name, the backbone defaults to the architecture named by
    ``resnet_name`` (and the filter counts default to the ResNet-18 sizes).
    Pretrained torchvision weights are downloaded in ``__init__``.

    Arguments:
        num_classes (int): number of classes to predict in each alphabet
        num_layers (int): number of convolutional layers (default=4).
        kernel_size (int): kernel size in each convolution (default=3).
        num_filters (list): output filter counts of the backbone stages
            (default=[64, 128, 256, 512], the ResNet-18 sizes).
        imsize (tuple): tuple of image height and width dimension.
        padding (bool, int, tuple): padding argument to convolution layers
            (default=True).
        batch_norm (bool): use batch normalization in each convolution layer
            (default=True).
        multi_head (bool): multi-headed training (default=False).
        resnet_name (str): torchvision backbone key (default="ResNet50").
        warp_num_layers (int): number of warp-layers per adaptable conv block
            (default=1).
        warp_num_filters: number of output filters internally in warp-layers,
            if `warp_num_layers>1`. Final number of output filters of
            warp-layers are always same as number of input filters
            (default=64).
        warp_residual_connection (bool): use residual connection in
            warp-layers (default=False).
        warp_act_fun (str): activation function in warp-layers (optional;
            NOTE(review): currently ignored — leakyrelu is hard-coded below).
        use_bottleneck / bottleneck_dim / new_cls / class_num: accepted for
            interface parity with ResNetFc; only use_bottleneck is stored.
        warp_batch_norm (bool): batch normalization in warp-layer.
        warp_final_head (bool): add a warp-layer to final output of model.
    """
    def __init__(self,
                 num_classes=10,
                 num_layers=5,
                 kernel_size=3,
                 num_filters=[64, 128, 256, 512],#resnet 18
                 #num_filters=[256, 512, 1024, 2048], #resnet 50 101
                 imsize=(28, 28),
                 padding=True,
                 batch_norm=True,
                 multi_head=False,
                 resnet_name="ResNet50",
                 warp_num_layers=1,
                 warp_num_filters=64,
                 warp_residual_connection=False,
                 warp_act_fun=None,
                 use_bottleneck=False,
                 bottleneck_dim=256,
                 new_cls=False,
                 class_num=1000,
                 warp_batch_norm=False,
                 warp_final_head=False):
        # NOTE(review): num_filters is a mutable default argument; it is
        # never mutated here, but replacing it with None+fallback would be
        # safer.
        super(WarpedResNet50, self).__init__()
        resnet_dict = {"ResNet18": models.resnet18, "ResNet34": models.resnet34, "ResNet50": models.resnet50,
                       "ResNet101": models.resnet101, "ResNet152": models.resnet152}
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.num_filters = num_filters
        self.imsize = imsize
        # self.batch_norm = batch_norm
        self.multi_head = multi_head
        self.warp_num_layers = warp_num_layers
        self.warp_num_filters = warp_num_filters
        self.warp_residual_connection = warp_residual_connection
        # NOTE(review): the warp_act_fun argument is ignored; leakyrelu is
        # hard-coded as the warp-layer activation.
        self.warp_act_fun = ACT_FUNS["leakyrelu"]
        self.warp_batch_norm = warp_batch_norm
        self.warp_final_head = warp_final_head
        self._conv_counter = 0
        self._warp_counter = 0
        self.num_filters = num_filters
        # Downloads pretrained ImageNet weights on first use.
        model_resnet = resnet_dict[resnet_name](pretrained=True)
        # self.conv1 = model_resnet.conv1
        # self.bn1 = model_resnet.bn1
        # self.relu = model_resnet.relu
        # self.maxpool = model_resnet.maxpool
        # self.layer1 = model_resnet.layer1
        # self.layer2 = model_resnet.layer2
        # self.layer3 = model_resnet.layer3
        # self.layer4 = model_resnet.layer4
        # self.avgpool = model_resnet.avgpool
        # self.feature_layers = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, \
        #                                     self.layer1, self.layer2, self.layer3, self.layer4, self.avgpool)
        def init_weights(m):
            # Kaiming/Xavier/normal init depending on layer type; currently
            # unused because the bottleneck/fc construction is commented out.
            classname = m.__class__.__name__
            if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
                nn.init.kaiming_uniform_(m.weight)
                nn.init.zeros_(m.bias)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
            elif classname.find('Linear') != -1:
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)
        self.use_bottleneck = use_bottleneck
        # self.new_cls = new_cls
        # if new_cls:
        #     if self.use_bottleneck:
        #         self.bottleneck = nn.Linear(model_resnet.fc.in_features, bottleneck_dim)
        #         self.fc = nn.Linear(bottleneck_dim, class_num)
        #         self.bottleneck.apply(init_weights)
        #         self.fc.apply(init_weights)
        #         self.__in_features = bottleneck_dim
        #     else:
        #         self.fc = nn.Linear(model_resnet.fc.in_features, class_num)
        #         self.fc.apply(init_weights)
        #         self.__in_features = model_resnet.fc.in_features
        # else:
        #     self.fc = model_resnet.fc
        #     self.__in_features = model_resnet.fc.in_features
        # def conv_block(nin):
        #     # Task adaptable conv block, same as OmniConv
        #     _block = [nn.Conv2d(nin,
        #                         num_filters,
        #                         kernel_size,
        #                         padding=padding),
        #               nn.MaxPool2d(2)]
        #     if batch_norm:
        #         _block.append(nn.BatchNorm2d(num_filters))
        #     _block.append(nn.ReLU())
        #     return nn.Sequential(*_block)
        def warp_layer(nin, nout):
            # We use same kernel_size and padding as OmniConv for simplicity
            return WarpLayer(nin, nout, kernel_size, padding,
                             self.warp_residual_connection,
                             self.warp_batch_norm,
                             self.warp_act_fun)
        def block(nin,last_layer):
            # Main block, wraps warp_layers around a conv_block.
            # Task-adaptable layer. Registers `conv{i}` (a backbone stage)
            # and a matching `warp{i}1` layer. `nin == 1` marks the stem
            # (conv1/bn1/relu/maxpool); otherwise `nin` is the stage's
            # channel count and conv{i} is model_resnet.layer{i-1}.
            self._conv_counter += 1
            if nin == 1:
                conv1 = getattr(model_resnet, "conv1")
                bn1 = getattr(model_resnet, "bn1")
                relu = getattr(model_resnet, "relu")
                maxpool = getattr(model_resnet, "maxpool")
                layer = nn.Sequential(*[conv1, bn1, relu, maxpool])
            else:
                layer = getattr(model_resnet, "layer{}".format(self._conv_counter-1))
            setattr(self, 'conv{}'.format(self._conv_counter), layer)
            # Warp-layers
            # nin = num_filters[self._conv_counter]
            # for _ in range(self.warp_num_layers):
            #     self._warp_counter = \
            #         self._warp_counter % self.warp_num_layers + 1
            #
            #     if self._warp_counter == self.warp_num_layers:
            #         nout = num_filters
            #     else:
            #         nout = self.warp_num_filters
            #
            #     setattr(self, 'warp{}{}'.format(self._conv_counter,
            #                                     self._warp_counter),
            #             warp_layer(nin, nout))
            #
            #     nin = nout
            if nin==1:
                # Stem outputs 64 channels for all torchvision ResNets.
                nin = 64
                nout = 64
            else:
                # nin = num_filters[self._conv_counter-1]
                nout = nin
            self._warp_counter = 1
            setattr(self, 'warp{}{}'.format(self._conv_counter,
                                            self._warp_counter),
                    warp_layer(nin, nout))
        # Build model: one stem block plus one block per backbone stage.
        block(1,last_layer=False)
        for index,num_filter in enumerate( num_filters):
            block(num_filter,last_layer=(index==len(num_filters)-1))
        if self.warp_final_head:
            # NOTE(review): this branch passes the num_filters LIST where
            # Linear/nn.Linear expect an int feature count — it looks broken
            # and appears unused; confirm before enabling warp_final_head.
            self.head = Linear(self.multi_head, num_filters[-1], num_filters)
            self.warp_head = nn.Linear(num_filters, num_classes)
        else:
            # NOTE(review): 8192 hard-codes the flattened feature size for
            # the ResNet-18 filter counts at the expected input resolution
            # (32768 for ResNet50/101) — TODO confirm against the dataloader.
            self.head = Linear(self.multi_head, 8192, num_classes) # resnet50 101 32768 18:8192
        self.squeeze = Squeeze()
    def forward(self, x, idx=None):
        """Forward-pass through model.

        Returns:
            tuple: (flattened embedding, logits) — or the warp head output
            alone when ``warp_final_head`` is set.
        """
        # x=x.repeat(1,3,1,1)
        for i in range(1, self._conv_counter + 1):
            # Task-adaptable layer
            x = getattr(self, 'conv{}'.format(i))(x)
            # print("x.shape: ",x.shape)
            # Warp-layer(s)
            # for j in range(1, self._warp_counter+1):
            #     x = getattr(self, 'warp{}{}'.format(i, j))(x)
            x = getattr(self, 'warp{}{}'.format(i, 1))(x)
        emb = self.squeeze(x)
        emb=emb.reshape(emb.size(0),-1)
        logits = self.head(emb, idx)
        if self.warp_final_head:
            return self.warp_head(x)
        return emb,logits
    def adapt_modules(self):
        """Iterator for task-adaptable modules (conv stages + head)."""
        for i in range(1, self.num_layers + 1):
            conv = getattr(self, 'conv{}'.format(i))
            yield conv
        yield self.head
    def warp_modules(self):
        """Iterator for warp-layer modules."""
        for i in range(1, self.num_layers + 1):
            for j in range(1, self.warp_num_layers + 1):
                warp = getattr(self, 'warp{}{}'.format(i, j))
                yield warp
        if self.warp_final_head:
            yield self.warp_head
    def init_adaptation(self):
        """Reset stats for new task."""
        # Reset head if multi-headed, otherwise null-op
        self.head.reset_parameters()
        # Reset BN running stats
        for m in self.modules():
            if hasattr(m, 'reset_running_stats'):
                m.reset_running_stats()
class MAMLWrapper(object):
    """Wrapper around the MAML meta-learner.

    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class name ('sgd' selects maml.SGD,
            anything else maml.Adam).
        meta_optimizer_cls: meta optimizer class name ('sgd' selects
            optim.SGD, anything else optim.Adam).
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self, model, optimizer_cls, meta_optimizer_cls,
                 optimizer_kwargs, meta_optimizer_kwargs, criterion):
        self.criterion = criterion
        self.model = model
        self.optimizer_cls = \
            maml.SGD if optimizer_cls.lower() == 'sgd' else maml.Adam
        self.meta = maml.MAML(optimizer_cls=self.optimizer_cls,
                              criterion=criterion,
                              model=model,
                              tensor=False,
                              **optimizer_kwargs)
        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.optimizer_kwargs = optimizer_kwargs
        self.meta_optimizer = self.meta_optimizer_cls(self.meta.parameters(),
                                                      **meta_optimizer_kwargs)
    def __call__(self, meta_batch, meta_train):
        # Materialize an inner (adaptation) and outer (meta) batch list per
        # task dataloader.
        tasks = []
        for t in meta_batch:
            t.dataset.train()
            inner = [b for b in t]
            # NOTE(review): the dataset is switched to train mode again for
            # the outer batches; this looks like it was meant to select a
            # held-out/validation split — confirm against the Task dataset.
            t.dataset.train()
            outer = [b for b in t]
            tasks.append((inner, outer))
        return self.run_meta_batch(tasks, meta_train=meta_train)
    def run_meta_batch(self, meta_batch, meta_train):
        """Run on meta-batch.

        Arguments:
            meta_batch (list): list of task-specific dataloaders
            meta_train (bool): meta-train on batch (take a meta-optimizer
                step); otherwise only evaluate.
        """
        loss, results = self.meta(meta_batch,
                                  return_predictions=False,
                                  return_results=True,
                                  create_graph=meta_train)
        if meta_train:
            loss.backward()
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()
        return results
"""Warped Gradient Descent.
Model wrapper that implements the WarpGrad logic
on a generic PyTorch model.
:author: Sebastian Flennerhag
"""
# pylint: disable=too-many-arguments
# pylint: disable=invalid-name
# pylint: disable=redefined-builtin
# pylint: disable=too-many-instance-attributes
# pylint: disable=protected-access
# pylint: disable=too-many-locals
# pylint: disable=arguments-differ
from collections import OrderedDict
import os
import uuid
import tempfile
import torch
from models.utils_warp import (copy, copy_opt, clone_state, load, clear,
unfreeze, freeze, zero_grad, get_groups)
class ReplayBuffer:
    """Cache for parameters during meta-training."""

    def __init__(self, inmem=True, tmpdir=None):
        """Initialize replay buffer.

        Arguments:
            inmem (bool): in-memory buffer (on CPU, default=True).
            tmpdir (str): if not inmem, root of buffer (optional).
        """
        self.inmem = inmem
        self._data_buffer = {}        # slot -> task data generator
        self._state_buffer = {}       # slot -> list of cloned states (inmem)
        self._optimizer_buffer = {}   # slot -> list of optimizer buffers
        self._idx = {}                # slot -> number of states stored
        if not inmem and tmpdir is None:
            tmpdir = tempfile.mkdtemp('_WGDTMP')
        self.tmpdir = tmpdir

    def clear(self):
        """Clear buffer."""
        self._data_buffer.clear()
        self._idx.clear()
        if self.inmem:
            self._state_buffer.clear()
            self._optimizer_buffer.clear()
        else:
            clear(self.tmpdir)

    def init(self, slot, data):
        """Initialize slot in buffer and attach dataloader.

        Raises:
            ValueError: if the slot was already initialized.
        """
        if slot in self._idx:
            raise ValueError('slot {} already in buffer'.format(slot))
        self._idx[slot] = 0
        self._data_buffer[slot] = data

    def update(self, slot, state, buffer=None):
        """Persists a copy of current parameters under given slot.

        Arguments:
            slot (str): a group-level identifier for the parameters
                (i.e. task id).
            state(OrderedDict, torch.Tensor): state_dict to add to buffer.
            buffer (list, dict, None): list of optimiser parameter buffers
                (default=None).
        """
        assert slot in self._idx, 'slot not in buffer. Call init_slot first'
        self._idx[slot] += 1
        if self.inmem:
            if slot not in self._state_buffer:
                assert self._idx[slot] == 1
                self._state_buffer[slot] = []
                if buffer is not None:
                    self._optimizer_buffer[slot] = []
            self._state_buffer[slot].append(clone_state(state, device='cpu'))
            if buffer is not None:
                self._optimizer_buffer[slot].append(copy_opt(buffer))
            return
        if buffer is not None:
            raise NotImplementedError(
                "Putting optimizer parameters on disk not implemented.")
        # BUG FIX: the original used '{}_{}.{}'.format(slot, idx, '.tar'),
        # producing double-dotted names like 'slot_1..tar'.
        fname = '{}_{}.tar'.format(slot, self._idx[slot])
        fpath = os.path.join(self.tmpdir, fname)
        torch.save(state, fpath)

    @property
    def dataset(self):
        """Current replay buffer."""
        if self.inmem:
            if self._optimizer_buffer:
                return (self._data_buffer, self._state_buffer,
                        self._optimizer_buffer)
            return self._data_buffer, self._state_buffer
        param_cache = load(self.tmpdir)
        return self._data_buffer, param_cache
class OptimizerParameters:
    """Container for optimizer parameters (per-parameter lr and momentum)."""

    def __init__(self, trainable, default_lr, default_momentum):
        """Initialize optimizer parameter manager.

        Arguments:
            trainable (bool): whether optimizer parameters are trainable.
            default_lr (float): initial learning rate (prior to training).
            default_momentum (float): initial momentum (prior to training).
        """
        self._opt = None
        self._trainable = trainable
        self._param_names = []
        self._lr = []
        self._momentum = []
        self.default_lr = default_lr
        self.default_momentum = default_momentum

    def init(self, named_parameters):
        """Create one lr and one momentum tensor per model parameter."""
        # pylint: disable=not-callable
        self._param_names, self._lr, self._momentum = [], [], []
        for name, param in named_parameters:
            self._param_names.append(name)
            self._lr.append(torch.tensor(self.default_lr,
                                         device=param.device,
                                         requires_grad=self.trainable))
            self._momentum.append(torch.tensor(self.default_momentum,
                                               device=param.device,
                                               requires_grad=self.trainable))

    @property
    def lr(self):
        """Learning rates, clamped to stay positive."""
        for rate in self._lr:
            if rate.item() < 0:
                rate.data.fill_(1e-6)
            yield rate

    @property
    def momentum(self):
        """Momentum rates, clamped to stay positive."""
        for mom in self._momentum:
            if mom.item() < 0:
                mom.data.fill_(1e-6)
            yield mom

    @property
    def trainable(self):
        """Trainable parameters flag."""
        return self._trainable

    @trainable.setter
    def trainable(self, trainable):
        self._trainable = trainable
        if self._trainable and self.default_momentum == 0:
            # A zero momentum cannot be meta-learned meaningfully; seed it.
            self.default_momentum = 1e-4
            for mom in self._momentum:
                mom.data.fill_(self.default_momentum)
        for param in self.parameters():
            param.requires_grad = self._trainable

    def parameters(self):
        """Optimizer parameters (all lrs, then all momenta)."""
        for param in self._lr + self._momentum:
            yield param

    def named_parameters(self):
        """Optimizer parameters with '.lr' / '.mom' suffixed names."""
        for name, param in zip(self._param_names, self._lr):
            yield name + '.lr', param
        for name, param in zip(self._param_names, self._momentum):
            yield name + '.mom', param

    def groups(self, parameters, tensor):
        """Parameter groups."""
        return get_groups(parameters, self, tensor=tensor)
class Parameters:
    """Attributes for parameter partitioning.

    Splits a model's parameters into adaptable (task-specific), warp
    (meta-learned), initialization, and optimizer parameter groups.
    """
    def __init__(self, model, adapt_modules, warp_modules,
                 optimizer_parameters):
        """Initialize partitioning.

        Arguments:
            model (torch.nn.Module): main model.
            adapt_modules (list, tuple): adaptable modules.
            warp_modules (list, tuple): warp modules.
            optimizer_parameters (OptimizerParameters): optimizer parameters
                manager.
        """
        self.model = model
        self.adapt_modules = adapt_modules
        self.warp_modules = warp_modules
        self._optimizer = None
        self._learn_optimizer = optimizer_parameters.trainable
        self._optimizer_parameters = optimizer_parameters
        self._optimizer_parameters.init(self.named_adapt_parameters())
        # Snapshot of the initial adaptable state; used to reset tasks and
        # (optionally) meta-learn the initialization.
        self._init_state = clone_state(self.adapt_state())
        self._init_parameters = [(n, p) for n, p in self._init_state.items()
                                 if p.requires_grad]
    def set_parameters(self, new_parameters):
        """Set task parameters to new_params.

        Arguments:
            new_parameters (list, torch.Tensor): list of task parameters.
        """
        copy(self.adapt_parameters(), new_parameters)
    def set_state(self, new_state):
        """Set task parameters to new_params.

        Arguments:
            new_state (OrderedDict): state dictionary over task parameters
                and buffers.
        """
        copy(self.adapt_state(), new_state)
    def init_state(self):
        """Return the cached initial adaptable state_dict."""
        return self._init_state
    def adapt_state(self):
        """Return state_dict for adapt modules."""
        model_state = self.model.state_dict(keep_vars=True)
        adapt_tensors = [id(t) for m in self.adapt_modules
                         for t in m.state_dict(keep_vars=True).values()]
        return OrderedDict((n, t) for n, t in model_state.items()
                           if id(t) in adapt_tensors)
    def adapt_parameters(self):
        """Adapt parameters."""
        for m in self.adapt_modules:
            for p in m.parameters():
                yield p
    def named_adapt_parameters(self):
        """Named adapt parameters."""
        # We can't use adapt_modules.named_parameters()
        # need to go through main model to get correct names
        adapt_ids = list(map(id, self.adapt_parameters()))
        for n, p in self.model.named_parameters():
            if id(p) in adapt_ids:
                yield n, p
    def parameters(self):
        """All parameters."""
        return self.model.parameters()
    def optimizer_buffer(self):
        """Return stored optimizer buffer, if any."""
        buffer = None
        if self._optimizer is not None:
            # opt.state is not ordered in pytorch v1
            buffer = []
            # BUG FIX: iterating the OrderedDict directly yields key strings,
            # so unpacking as (n, _) raised ValueError; iterate .items().
            param_names = [n for n, _ in self.adapt_state().items()]
            for n in param_names:
                # check since opt.state is a dict factory
                if n in self._optimizer.state:
                    buffer.append(self._optimizer.state[n])
        return buffer
    def optimizer_parameters(self):
        """Optimizer parameters."""
        return self._optimizer_parameters.parameters()
    def named_optimizer_parameters(self):
        """Named optimizer parameters."""
        return self._optimizer_parameters.named_parameters()
    def init_parameters(self):
        """Initialization parameters."""
        for _, p in self.named_init_parameters():
            yield p
    def named_init_parameters(self, suffix='.init'):
        """Named initialization parameters."""
        for n, p in self._init_parameters:
            if suffix is not None:
                n += suffix
            yield n, p
    def warp_parameters(self):
        """Warp parameters."""
        for m in self.warp_modules:
            for p in m.parameters():
                yield p
    def named_warp_parameters(self, suffix=None):
        """Named warp parameters."""
        # We can't use warp_modules.named_parameters()
        # need to go through main model to get correct names
        meta_param_ids = list(map(id, self.warp_parameters()))
        for n, p in self.model.named_parameters():
            if id(p) in meta_param_ids:
                if suffix is not None:
                    n += suffix
                yield n, p
    def meta_parameters(self,
                        include_warp=True,
                        include_init=True,
                        include_optimizer=True):
        """All meta-parameters.

        Arguments:
            include_warp (bool): include warp parameters.
            include_init (bool): include the initialization.
            include_optimizer (bool): include optimizer parameters.
        """
        if self.learn_optimizer and include_optimizer:
            for p in self.optimizer_parameters():
                yield p
        if include_init:
            for p in self.init_parameters():
                yield p
        if include_warp:
            for p in self.warp_parameters():
                yield p
    def named_meta_parameters(self,
                              include_warp=True,
                              include_init=True,
                              include_opt=True):
        """Named meta parameters.

        Arguments:
            include_warp (bool): include warp parameters.
            include_init (bool): include the initialization.
            include_opt (bool): include optimizer parameters
                (if `learn_optimizer=True`).
        """
        if self.learn_optimizer and include_opt:
            for n, p in self.named_optimizer_parameters():
                yield n, p
        if include_init:
            for n, p in self.named_init_parameters():
                yield n, p
        if include_warp:
            for n, p in self.named_warp_parameters():
                yield n, p
    def optimizer_parameter_groups(self, tensor=False):
        """Parameter groups for optimizer.

        Arguments:
            tensor (bool): return parameters as tensors
                (use with warpgrad.optim).
        """
        return self._optimizer_parameters.groups(
            self.adapt_parameters(), tensor)
    def register_optimizer(self, optimizer):
        """Register an optimizer during task training to collect buffers."""
        self._optimizer = optimizer
    def unregister_optimizer(self):
        """Unregister optimizer to stop collecting buffers."""
        self._optimizer = None
    @property
    def learn_optimizer(self):
        """Learn optimizer parameters."""
        return self._optimizer_parameters.trainable
    @learn_optimizer.setter
    def learn_optimizer(self, learn_optimizer):
        self._optimizer_parameters.trainable = learn_optimizer
class Warp(Parameters):
    """Model wrapper for WarpGrad.

    Wraps a task learner together with its warp layers, a replay buffer of
    adaptation trajectories, and the dual meta-updater. Forward calls
    optionally snapshot the current adapt parameters into the buffer.
    """
    def __init__(self, model, adapt_modules, warp_modules,
                 updater, buffer, optimizer_parameters):
        """Initialize warp over given model.
        Args:
            model (torch.nn.Module): main model.
            adapt_modules (torch.nn.Module): adapt modules in main model.
            warp_modules (torch.nn.Module): warp modules in main model.
            updater (updater.DualUpdater): the meta parameter update handler.
            buffer (ReplayBuffer): adapt parameter replay buffer.
            optimizer_parameters (OptimizerParameters): handler of optimizer
                parameters.
        """
        super(Warp, self).__init__(model,
                                   adapt_modules,
                                   warp_modules,
                                   optimizer_parameters)
        self.updater = updater
        self._task = None      # id of the task currently being adapted
        self._collect = True   # snapshot adapt params on each forward pass
        self.buffer = buffer
        self.zero_meta_grads()
        self.zero_task_grads()
    def __call__(self, *inputs, cache_parameters=None):
        # Default to the instance-level collection flag; when collecting,
        # persist the current adapt state *before* running the forward pass.
        if cache_parameters is None:
            cache_parameters = self._collect
        if cache_parameters:
            self._dump()
        return self.model(*inputs)
    def register_task(self, data):
        """Register a distinct task in buffer.
        Args:
            data: the tasks data generator.
        """
        # Fresh random id per task. NOTE(review): relies on `uuid` being
        # imported at module level (import not visible in this chunk).
        self._task = uuid.uuid4().hex
        self.buffer.init(self._task, data)
    def init_adaptation(self, reset_adapt_parameters=None):
        """Calls init_adaptation on model.
        Arguments:
            reset_adapt_parameters (bool): whether to reset the initialisation
                of adaptable parameters. If not specified, will be reset if
                the initialization is meta-learned (in the `updater`).
        """
        self.model.init_adaptation()
        if reset_adapt_parameters is None:
            # Will be 0 if no meta-objective for initialization is specified
            reset_adapt_parameters = self.updater.init_objective
        if reset_adapt_parameters:
            # Overwrite adapt params with the meta-learned initialization.
            copy(self.adapt_state(), self.init_state())
        # Only adapt parameters receive task gradients; meta params frozen.
        freeze(self.meta_parameters())
        unfreeze(self.adapt_parameters())
        self.model.train()
    def clear(self):
        """Clears parameter trajectory buffer."""
        self.buffer.clear()
    def collect(self):
        """Switch on task parameter collection in buffer."""
        self._collect = True
    def no_collect(self):
        """Switch off task parameter collection in buffer."""
        self._collect = False
    def train(self):
        """Switch to train mode in task learner."""
        self.model.train()
    def eval(self):
        """Switch to eval mode in task learner."""
        self.model.eval()
    def zero_meta_grads(self):
        """Set meta gradients to zero."""
        zero_grad(list(self.meta_parameters()))
    def zero_task_grads(self):
        """Set task learner gradients to zero."""
        zero_grad(list(self.adapt_parameters()))
    def backward(self, *args, retain_trajectories=False,
                 retain_optimizer=False, **kwargs):
        """Compute gradients of meta-parameters.
        Arguments:
            *args: arguments to pass to the updater.
            retain_trajectories (bool): keep current buffer (default=False).
            retain_optimizer (bool): keep registered optimizer (default=False).
            **kwargs: keyword arguments to pass to the updater.
        """
        # Pause buffer collection while the updater replays trajectories,
        # then restore the previous collection state.
        collecting = self.collecting
        if collecting:
            self.no_collect()
        self.updater.backward(self, *args, **kwargs)
        if not retain_trajectories:
            self.clear()
        if not retain_optimizer:
            self.unregister_optimizer()
        if collecting:
            self.collect()
    def _dump(self):
        """Persists a copy of current parameters under [task, iter]."""
        self.buffer.update(self._task,
                           self.adapt_state(),
                           self.optimizer_buffer())
    @property
    def collecting(self):
        """Flag for whether we are collect parameters in buffer."""
        return self._collect
    @property
    def dataset(self):
        """Return copy of codes."""
        return self.buffer.dataset
"""Runtime helpers"""
# pylint: disable=invalid-name,too-many-arguments,too-many-instance-attributes
import os
from os.path import join
import numpy as np
def convert_arg(arg):
    """Parse a CLI string into None/bool/float/int; return it unchanged on failure."""
    # pylint: disable=broad-except
    lowered = arg.lower()
    special = {'none': None, 'false': False, 'true': True}
    if lowered in special:
        return special[lowered]
    # Strings containing a dot are tried as floats, everything else as int;
    # unparseable values fall through unchanged.
    caster = float if '.' in arg else int
    try:
        return caster(arg)
    except Exception:
        return arg
def build_kwargs(args):
    """Build a kwargs dict from a flat [key, value, key, value, ...] list.

    Values are converted with `convert_arg`; colon-separated values become
    tuples of converted items. The input list is consumed (emptied).
    """
    kwargs = {}
    if not args:
        return kwargs
    assert len(args) % 2 == 0, "argument list %r does not appear to have key, value pairs" % args
    while args:
        key = args.pop(0)
        raw = args.pop(0)
        if ':' in raw:
            value = tuple(convert_arg(part) for part in raw.split(':'))
        else:
            value = convert_arg(raw)
        kwargs[str(key)] = value
    return kwargs
def compute_ncorrect(p, y):
    """Number of rows in *p* whose argmax equals the label in *y*."""
    predicted_classes = p.max(1)[1]
    return (predicted_classes == y).sum().item()
def compute_auc(x):
    """Compute AUC over *x* via the composite trapezoidal rule, normalized by len(x)."""
    T = len(x)
    area = 0
    # Sum the trapezoid between each pair of consecutive points.
    for prev, cur in zip(x, x[1:]):
        area += ((cur - prev) / 2 + prev) / T
    return area
def unlink(path):
    """Remove every '.log' file directly under *path*."""
    for name in os.listdir(path):
        full = os.path.join(path, name)
        if full.endswith('.log'):
            os.unlink(full)
###############################################################################
def write(step, meta_loss, loss, accuracy, losses, accuracies, f):
    """Append one results row (step, summary stats, per-batch series) to file *f*."""
    # Per-batch series are serialized as semicolon-terminated floats.
    loss_str = "".join("{:f};".format(v) for v in losses)
    acc_str = "".join("{:f};".format(v) for v in accuracies)
    msg = "{:d},{:f},{:f},{:f},{:s},{:s}\n".format(
        step, meta_loss, loss, accuracy, loss_str, acc_str)
    with open(f, 'a') as fo:
        fo.write(msg)
def log_status(results, idx, time):
    """Print a one-line training status summary for step *idx*."""
    # pylint: disable=unbalanced-tuple-unpacking,too-many-star-expressions
    msg = ("[{:9s}] time:{:3.3f} "
           "train: outer={:0.4f} inner={:0.4f} acc={:2.2f} ").format(
               str(idx),
               time,
               results.train_meta_loss,
               results.train_loss,
               results.train_acc)
    print(msg)
def write_train_res(results, step, log_dir):
    """Append the meta-train step's train-split results to the train log file."""
    out_file = join(log_dir, 'results_train_train.log')
    write(step,
          results.train_meta_loss,
          results.train_loss,
          results.train_acc,
          results.train_losses,
          results.train_accs,
          out_file)
def write_val_res(results, step, case, log_dir):
    """Append per-task train and val results for evaluation *case* to log files."""
    for task_id, res in enumerate(results):
        # One train log and one val log per task.
        train_file = join(log_dir, 'results_{}_{}_train.log'.format(task_id, case))
        val_file = join(log_dir, 'results_{}_{}_val.log'.format(task_id, case))
        write(step,
              res.train_meta_loss,
              res.train_loss,
              res.train_acc,
              res.train_losses,
              res.train_accs,
              train_file)
        write(step,
              res.val_meta_loss,
              res.val_loss,
              res.val_acc,
              res.val_losses,
              res.val_accs,
              val_file)
###############################################################################
class Res:
    """Per-task results container.

    Attributes:
        losses (list): per-batch losses (numpy array after ``aggregate``).
        accs (list): per-batch accuracies (numpy array after ``aggregate``).
        ncorrects (list): per-batch correct-prediction counts.
        nsamples (list): per-batch sample counts.
        meta_loss (float): AUC over losses, set by ``aggregate``.
        loss (float): mean loss, set by ``aggregate``.
        acc (float): sample-weighted accuracy, set by ``aggregate``.
    """
    def __init__(self):
        # Raw per-batch records; converted to numpy arrays by aggregate().
        self.losses = []
        self.accs = []
        self.ncorrects = []
        self.nsamples = []
        self.meta_loss = 0
        self.loss = 0
        self.acc = 0
    def log(self, loss, pred, target):
        """Record loss and accuracy for one batch of predictions."""
        batch_size = target.size(0)
        ncorr = compute_ncorrect(pred.data, target.data)
        self.losses.append(loss)
        self.ncorrects.append(ncorr)
        self.nsamples.append(batch_size)
        self.accs.append(ncorr / batch_size)
    def aggregate(self):
        """Convert records to arrays and compute summary statistics."""
        self.accs = np.array(self.accs)
        self.losses = np.array(self.losses)
        self.nsamples = np.array(self.nsamples)
        self.ncorrects = np.array(self.ncorrects)
        self.loss = self.losses.mean()
        self.meta_loss = compute_auc(self.losses)
        # Weight accuracy by batch size, not a plain mean of batch accs.
        self.acc = self.ncorrects.sum() / self.nsamples.sum()
class AggRes:
    """Results aggregation container.

    Aggregates a list of per-task ``Res`` objects (train split only in
    this variant) into mean losses/accuracies across tasks.
    """
    def __init__(self, results):
        self.train_res = results
        self.aggregate_train()
    def aggregate_train(self):
        """Aggregate train results across tasks."""
        (self.train_meta_loss,
         self.train_loss,
         self.train_acc,
         self.train_losses,
         self.train_accs) = self.aggregate(self.train_res)
    @staticmethod
    def aggregate(results):
        """Aggregate losses and accs across Res instances."""
        # Stack to (n_batches, n_tasks) so axis=1 averages over tasks.
        agg_losses = np.stack([res.losses for res in results], axis=1)
        agg_ncorrects = np.stack([res.ncorrects for res in results], axis=1)
        agg_nsamples = np.stack([res.nsamples for res in results], axis=1)
        mean_losses = agg_losses.mean(axis=1)
        mean_accs = agg_ncorrects.sum(axis=1) / agg_nsamples.sum(axis=1)
        return (compute_auc(mean_losses),
                agg_losses.mean(),
                agg_ncorrects.sum() / agg_nsamples.sum(),
                mean_losses,
                mean_accs)
def consolidate(agg_res):
    """Merge a list of agg_res into one agg_res"""
    # NOTE(review): each `r` is expected to expose both `train_res` and
    # `val_res`, but AggRes in this file currently sets only `train_res`
    # (`aggregate_val` is commented out), so this would raise
    # AttributeError — confirm whether consolidate is still used.
    results = [sum((r.train_res, r.val_res), ()) for r in agg_res]
    return AggRes(results)
"""Meta-learners for Omniglot experiment.
Based on original implementation:
https://github.com/amzn/metalearn-leap
"""
import random
from abc import abstractmethod
from torch import nn
from torch import optim
import torch
import numpy as np
from copy import deepcopy
from scipy.spatial.distance import cdist
import warpgrad
from leap import Leap
from leap.utils import clone_state_dict
from models.warpmain_utils import Res, AggRes
class BaseWrapper(object):
    """Generic training wrapper.
    Arguments:
        criterion (func): loss criterion to use.
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
    """
    def __init__(self, criterion, model, optimizer_cls, optimizer_kwargs):
        self.criterion = criterion
        self.model = model
        # Only 'sgd' (case-insensitive) maps to SGD; any other name gets Adam.
        self.optimizer_cls = \
            optim.SGD if optimizer_cls.lower() == 'sgd' else optim.Adam
        self.optimizer_kwargs = optimizer_kwargs
    def __call__(self, tasks, meta_train=True):
        return self.run_tasks(tasks, meta_train=meta_train)
    @abstractmethod
    def _partial_meta_update(self, loss, final):
        """Meta-model specific meta update rule.
        Arguments:
            loss (nn.Tensor): loss value for given mini-batch.
            final (bool): whether iteration is the final training step.
        """
        # NOTE(review): the exception is instantiated but never raised —
        # likely intended as `raise NotImplementedError(...)`.
        NotImplementedError('Implement in meta-learner class wrapper.')
    @abstractmethod
    def _final_meta_update(self):
        """Meta-model specific meta update rule."""
        # NOTE(review): missing `raise` here as well.
        NotImplementedError('Implement in meta-learner class wrapper.')
    def run_tasks(self, tasks, meta_train):
        """Train on a mini-batch tasks and evaluate test performance.
        Arguments:
            tasks (list, torch.utils.data.DataLoader): list of task-specific
                dataloaders.
            meta_train (bool): whether current run in during meta-training.
        """
        # assert len(tasks)==2
        results = []
        # print(tasks)
        for task in tasks:
            # print("task",task)
            task.dataset.train()
            trainres = self.run_task(task, train=True, meta_train=meta_train)
            # task.dataset.eval()
            # valres = self.run_task(task, train=False, meta_train=False)
            results.append(trainres)
            # break # the first task is unsupervised training
        ##
        # Aggregate per-task train results (no val split in this variant).
        results = AggRes(results)
        # Meta gradient step
        if meta_train:
            self._final_meta_update()
        return results
    def run_task(self, task, train, meta_train):
        """Run model on a given task.
        Arguments:
            task (torch.utils.data.DataLoader): task-specific dataloaders.
            train (bool): whether to train on task.
            meta_train (bool): whether to meta-train on task.
        """
        optimizer = None
        if train:
            # Re-initialize adaptation state and build a fresh task optimizer.
            self.model.init_adaptation()
            self.model.train()
            optimizer = self.optimizer_cls(
                self.model.parameters(), **self.optimizer_kwargs)
        else:
            self.model.eval()
        return self.run_batches(
            task, optimizer, train=train, meta_train=meta_train)
    def build_global_center(self,batches,device):
        """Cluster features of the whole loader into per-class centers.

        Runs every batch through the model, seeds class centroids from the
        softmax predictions, then refines them with one hard re-assignment
        pass under cosine distance (SHOT-style pseudo-labelling).
        Returns (centers, pseudo-labels).
        """
        start_test = True
        with torch.no_grad():
            for n, (inputs, labels) in enumerate(batches):
                # data = iter_test.next()
                inputs = inputs.to(device)
                # Model is assumed to return (features, logits) —
                # TODO confirm against the model definition.
                feas,logits = self.model(inputs)
                outputs = logits
                if start_test:
                    all_fea = feas.float().cpu()
                    all_output = outputs.float().cpu()
                    all_label = labels.float()
                    start_test = False
                else:
                    all_fea = torch.cat((all_fea, feas.float().cpu()), 0)
                    all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                    all_label = torch.cat((all_label, labels.float()), 0)
        all_output = nn.Softmax(dim=1)(all_output)
        _, predict = torch.max(all_output, 1)
        accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
        # Append a constant bias feature and L2-normalize rows before
        # computing cosine distances.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
        all_fea = all_fea.float().cpu().numpy()
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()
        # Soft-assignment weighted centroids...
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc, 'cosine')
        pred_label = dd.argmin(axis=1)
        acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
        # ...then one hard re-assignment pass to sharpen the centers.
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        # NOTE(review): centers are moved to CUDA unconditionally,
        # ignoring the `device` argument — confirm CPU runs are unsupported.
        center = torch.from_numpy(initc).cuda()
        log_str = 'predict acc to cluster acc = {:.2f}% -> {:.2f}%'.format(accuracy * 100, acc * 100)
        # args.out_file.write(log_str + '\n')
        # args.out_file.flush()
        print(log_str + '\n')
        # self.global_center=center
        return center,pred_label
    def obtain_center(self,data):
        """Compute class centers from a single (inputs, labels) pair.

        Same centroid construction as `build_global_center`, but for one
        pre-assembled batch instead of a full loader.
        """
        start_test = True
        with torch.no_grad():
            inputs = data[0]
            labels = data[1]
            feas,logits = self.model(inputs)
            outputs = logits
            all_fea = feas.float().cpu()
            all_output = outputs.float().cpu()
            all_label = labels.float().cpu()
        all_output = nn.Softmax(dim=1)(all_output)
        _, predict = torch.max(all_output, 1)
        accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
        # Bias feature + row normalization, as in build_global_center.
        all_fea = torch.cat((all_fea, torch.ones(all_fea.size(0), 1)), 1)
        all_fea = (all_fea.t() / torch.norm(all_fea, p=2, dim=1)).t()
        all_fea = all_fea.float().cpu().numpy()
        K = all_output.size(1)
        aff = all_output.float().cpu().numpy()
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        dd = cdist(all_fea, initc, 'cosine')
        pred_label = dd.argmin(axis=1)
        acc = np.sum(pred_label == all_label.float().numpy()) / len(all_fea)
        aff = np.eye(K)[pred_label]
        initc = aff.transpose().dot(all_fea)
        initc = initc / (1e-8 + aff.sum(axis=0)[:, None])
        center = torch.from_numpy(initc).cuda()
        log_str = 'predict_acc = {:.2f}% -> cluster_acc {:.2f}%'.format(accuracy * 100, acc * 100)
        # args.out_file.write(log_str + '\n')
        # args.out_file.flush()
        # print(log_str + '\n')
        return center
    def obtain_label(self,features_target, center,device=None):
        """Assign each feature row the label of its nearest center (cosine).

        NOTE(review): the `device` argument is unused; results are always
        placed on CUDA.
        """
        # Append the same bias feature used when centers were built.
        features_target = torch.cat((features_target, torch.ones(features_target.size(0), 1).cuda()), 1)
        fea = features_target.float().detach().cpu().numpy()
        center = center.float().detach().cpu().numpy()
        dis = cdist(fea, center, 'cosine') + 1
        pred = np.argmin(dis, axis=1)
        pred = torch.from_numpy(pred).cuda()
        return pred
    def run_batches(self, batches, optimizer, train=False, meta_train=False):
        """Iterate over task-specific batches.
        Arguments:
            batches (torch.utils.data.DataLoader): task-specific dataloaders.
            optimizer (torch.nn.optim): optimizer instance if training is True.
            train (bool): whether to train on task.
            meta_train (bool): whether to meta-train on task.
        """
        device = next(self.model.parameters()).device
        # t1,t2=[],[]
        # for n, (input, target) in enumerate(batches):
        #     t1.append(target)
        # for n, (input, target) in enumerate(batches):
        #     t2.append(target)
        # print("t1",t1,"\nt2",t2)
        res = Res()
        N = len(batches)
        # Pseudo-label targets come from clustered feature centers.
        center, _= self.build_global_center(batches,device)
        # NOTE(review): when train=False run_task passes optimizer=None and
        # this call would raise AttributeError — confirm the eval path is
        # never exercised through this base implementation.
        optimizer.zero_grad()
        for n, (input, target) in enumerate(batches):
            input = input.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            # Evaluate model
            _,prediction = self.model(input)
            # Second forward pass (no grad) to get features for pseudo-labels.
            with torch.no_grad():
                features_support,_ = self.model(input)
            pred = self.obtain_label(features_support, center,None)
            # Loss is computed against cluster pseudo-labels, not `target`;
            # `target` is only used for accuracy bookkeeping in `res`.
            loss = self.criterion(prediction, pred)
            res.log(loss=loss.item(), pred=prediction, target=target)
            # TRAINING #
            if not train:
                continue
            # final = (n+1) == N
            loss.backward()
            # if meta_train:
            #     self._partial_meta_update(loss, final)
            optimizer.step()
            optimizer.zero_grad()
            # if final:
            #     break
        ###
        res.aggregate()
        return res
def get_para_num(net):
    """Return total and trainable parameter counts of a torch module."""
    counts = {'Total': 0, 'Trainable': 0}
    for p in net.parameters():
        n = p.numel()
        counts['Total'] += n
        if p.requires_grad:
            counts['Trainable'] += n
    return counts
class WarpGradWrapper(BaseWrapper):
    """Wrapper around WarpGrad meta-learners.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        meta_kwargs (dict): kwargs to pass to meta-learner upon construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self,
                 model,
                 optimizer_cls,
                 meta_optimizer_cls,
                 optimizer_kwargs,
                 meta_optimizer_kwargs,
                 meta_kwargs,
                 criterion):
        # Buffer of adaptation trajectories (kept in memory or on disk).
        replay_buffer = warpgrad.ReplayBuffer(
            inmem=meta_kwargs.pop('inmem', True),
            tmpdir=meta_kwargs.pop('tmpdir', None))
        # Optionally meta-learned task-optimizer hyper-parameters.
        optimizer_parameters = warpgrad.OptimizerParameters(
            trainable=meta_kwargs.pop('learn_opt', False),
            default_lr=optimizer_kwargs['lr'],
            default_momentum=optimizer_kwargs['momentum']
            if 'momentum' in optimizer_kwargs else 0.)
        updater = warpgrad.DualUpdater(criterion, **meta_kwargs)
        # p = get_para_num(model)
        # print("before warp",p)
        # Wrap the model so forward passes record adapt-parameter states.
        model = warpgrad.Warp(model=model,
                              adapt_modules=list(model.adapt_modules()),
                              warp_modules=list(model.warp_modules()),
                              updater=updater,
                              buffer=replay_buffer,
                              optimizer_parameters=optimizer_parameters)
        # p = get_para_num(model)
        # print("after warp", p)
        # total_num = sum(p.numel() for p in model.parameters())
        # trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # print('Total', total_num, 'Trainable', trainable_num)
        super(WarpGradWrapper, self).__init__(criterion,
                                              model,
                                              optimizer_cls,
                                              optimizer_kwargs)
        self.meta_optimizer_cls = optim.SGD \
            if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        # Separate meta learning rates for warp, init and optimizer params,
        # each defaulting to the shared 'lr'.
        lra = meta_optimizer_kwargs.pop(
            'lr_adapt', meta_optimizer_kwargs['lr'])
        lri = meta_optimizer_kwargs.pop(
            'lr_init', meta_optimizer_kwargs['lr'])
        lrl = meta_optimizer_kwargs.pop(
            'lr_lr', meta_optimizer_kwargs['lr'])
        self.meta_optimizer = self.meta_optimizer_cls(
            [{'params': self.model.init_parameters(), 'lr': lri},
             {'params': self.model.warp_parameters(), 'lr': lra},
             {'params': self.model.optimizer_parameters(), 'lr': lrl}],
            **meta_optimizer_kwargs)
    def _partial_meta_update(self, loss, final):
        # WarpGrad accumulates through the replay buffer; no per-step work.
        pass
    def _final_meta_update(self):
        def step_fn():
            # Invoked by the updater once meta-gradients are populated.
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()
        self.model.backward(step_fn, **self.optimizer_kwargs)
    def run_task(self, task, train, meta_train):
        """Run model on a given task, first adapting and then evaluating"""
        if meta_train and train:
            # Register new task in buffer.
            self.model.register_task(task)
            self.model.collect()
        else:
            # Make sure we're not collecting non-meta-train data
            self.model.no_collect()
        optimizer = None
        if train:
            # Initialize model adaptation
            self.model.init_adaptation()
            optimizer = self.optimizer_cls(
                self.model.optimizer_parameter_groups(),
                **self.optimizer_kwargs)
            if self.model.collecting and self.model.learn_optimizer:
                # Register optimiser to collect potential momentum buffers
                self.model.register_optimizer(optimizer)
        else:
            self.model.eval()
        return self.run_batches(
            task, optimizer, train=train, meta_train=meta_train)
class LeapWrapper(BaseWrapper):
    """Wrapper around the Leap meta-learner.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        meta_kwargs (dict): kwargs to pass to meta-learner upon construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self,
                 model,
                 optimizer_cls,
                 meta_optimizer_cls,
                 optimizer_kwargs,
                 meta_optimizer_kwargs,
                 meta_kwargs,
                 criterion):
        super(LeapWrapper, self).__init__(criterion,
                                          model,
                                          optimizer_cls,
                                          optimizer_kwargs)
        # Leap meta-learner holds the meta-initialization being learned.
        self.meta = Leap(model, **meta_kwargs)
        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.meta_optimizer = self.meta_optimizer_cls(
            self.meta.parameters(), **meta_optimizer_kwargs)
    def _partial_meta_update(self, l, final):
        # Accumulate the Leap meta-gradient after each task step.
        self.meta.update(l, self.model)
    def _final_meta_update(self):
        # Normalize accumulated updates before the meta-optimizer step.
        self.meta.normalize()
        self.meta_optimizer.step()
        self.meta_optimizer.zero_grad()
    def run_task(self, task, train, meta_train):
        if meta_train:
            # Start a new task trajectory in the meta-learner.
            self.meta.init_task()
        if train:
            # Load the current meta-initialization into the task model.
            self.meta.to(self.model)
        return super(LeapWrapper, self).run_task(task, train, meta_train)
class MAMLWrapper(object):
    """Wrapper around the MAML meta-learner.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self, model, optimizer_cls, meta_optimizer_cls,
                 optimizer_kwargs, meta_optimizer_kwargs, criterion):
        self.criterion = criterion
        self.model = model
        # NOTE(review): `maml` is not imported in this chunk — constructing
        # this wrapper raises NameError unless `import maml` exists
        # elsewhere in the file; confirm.
        self.optimizer_cls = \
            maml.SGD if optimizer_cls.lower() == 'sgd' else maml.Adam
        self.meta = maml.MAML(optimizer_cls=self.optimizer_cls,
                              criterion=criterion,
                              model=model,
                              tensor=False,
                              **optimizer_kwargs)
        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.optimizer_kwargs = optimizer_kwargs
        self.meta_optimizer = self.meta_optimizer_cls(self.meta.parameters(),
                                                      **meta_optimizer_kwargs)
    def __call__(self, meta_batch, meta_train):
        # Materialize inner/outer batch lists for each task.
        tasks = []
        for t in meta_batch:
            t.dataset.train()
            inner = [b for b in t]
            # NOTE(review): train() is called again for the outer split —
            # possibly intended to be `t.dataset.eval()`; confirm.
            t.dataset.train()
            outer = [b for b in t]
            tasks.append((inner, outer))
        return self.run_meta_batch(tasks, meta_train=meta_train)
    def run_meta_batch(self, meta_batch, meta_train):
        """Run on meta-batch.
        Arguments:
            meta_batch (list): list of task-specific dataloaders
            meta_train (bool): meta-train on batch.
        """
        # Second-order graph is only built when meta-training.
        loss, results = self.meta(meta_batch,
                                  return_predictions=False,
                                  return_results=True,
                                  create_graph=meta_train)
        if meta_train:
            loss.backward()
            self.meta_optimizer.step()
            self.meta_optimizer.zero_grad()
        return results
class NoWrapper(BaseWrapper):
    """Wrapper for baseline without any meta-learning.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self, model, optimizer_cls, optimizer_kwargs, criterion):
        super(NoWrapper, self).__init__(criterion,
                                        model,
                                        optimizer_cls,
                                        optimizer_kwargs)
        # Snapshot of the initial weights; restored before every task so
        # each task starts from the same (non-meta-learned) initialization.
        self._original = clone_state_dict(model.state_dict(keep_vars=True))
    def __call__(self, tasks, meta_train=False):
        # Meta-training is always disabled for this baseline.
        return super(NoWrapper, self).__call__(tasks, meta_train=False)
    def run_task(self, task, train, meta_train):
        if train:
            # Reset to the stored initialization before adapting.
            self.model.load_state_dict(self._original)
        return super(NoWrapper, self).run_task(task, train, meta_train)
    def _partial_meta_update(self, loss, final):
        pass
    def _final_meta_update(self):
        pass
class _FOWrapper(BaseWrapper):
    """Base wrapper for First-order MAML and Reptile.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        criterion (func): loss criterion to use.
    """
    # Set by subclasses: True accumulates adapted parameter values
    # (Reptile), False accumulates final gradients (FOMAML).
    _all_grads = None
    def __init__(self, model, optimizer_cls, meta_optimizer_cls,
                 optimizer_kwargs, meta_optimizer_kwargs, criterion):
        super(_FOWrapper, self).__init__(criterion,
                                         model,
                                         optimizer_cls,
                                         optimizer_kwargs)
        self.meta_optimizer_cls = \
            optim.SGD if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
        self.meta_optimizer_kwargs = meta_optimizer_kwargs
        self._counter = 0     # tasks accumulated since the last meta step
        self._updates = None  # per-parameter accumulated updates
        # The meta-optimizer updates the stored initialization directly.
        self._original = clone_state_dict(
            self.model.state_dict(keep_vars=True))
        params = [p for p in self._original.values()
                  if getattr(p, 'requires_grad', False)]
        self.meta_optimizer = self.meta_optimizer_cls(params,
                                                      **meta_optimizer_kwargs)
    def run_task(self, task, train, meta_train):
        if meta_train:
            self._counter += 1
        if train:
            # Start each task from the stored initialization.
            self.model.load_state_dict(self._original)
        return super(_FOWrapper, self).run_task(task, train, meta_train)
    def _partial_meta_update(self, loss, final):
        # Only accumulate on the final training step of the task.
        if not final:
            return
        if self._updates is None:
            # Lazily allocate zeroed accumulators for trainable params.
            self._updates = {}
            for n, p in self._original.items():
                if not getattr(p, 'requires_grad', False):
                    continue
                if p.size():
                    self._updates[n] = p.new(*p.size()).zero_()
                else:
                    # Scalar parameters have an empty size tuple.
                    self._updates[n] = p.clone().zero_()
        for n, p in self.model.state_dict(keep_vars=True).items():
            if n not in self._updates:
                continue
            if self._all_grads is True:
                # Reptile: accumulate the adapted parameter values.
                self._updates[n].add_(p.data)
            else:
                # FOMAML: accumulate the final gradients.
                self._updates[n].add_(p.grad.data)
    def _final_meta_update(self):
        # Average accumulated updates over the number of tasks.
        for n, p in self._updates.items():
            p.data.div_(self._counter)
        for n, p in self._original.items():
            if n not in self._updates:
                continue
            if self._all_grads:
                # Reptile meta-gradient: init minus mean adapted params.
                p.grad = p.data - self._updates[n].data
            else:
                p.grad = self._updates[n]
        self.meta_optimizer.step()
        self.meta_optimizer.zero_grad()
        self._counter = 0
        self._updates = None
class ReptileWrapper(_FOWrapper):
    """Wrapper for Reptile.

    Identical to _FOWrapper except that the meta-update accumulates the
    adapted parameter values themselves rather than their gradients.

    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        criterion (func): loss criterion to use.
    """
    # Accumulate parameter values (Reptile-style meta update).
    _all_grads = True
class FOMAMLWrapper(_FOWrapper):
    """Wrapper for FOMAML.

    Identical to _FOWrapper except that the meta-update accumulates the
    final task gradients (first-order MAML).

    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        meta_optimizer_cls: meta optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
            construction.
        criterion (func): loss criterion to use.
    """
    # Accumulate gradients, not parameter values (first-order MAML).
    _all_grads = False
class FtWrapper(BaseWrapper):
    """Wrapper for Multi-headed finetuning.
    This wrapper differs from others in that it blends batches from all tasks
    into a single epoch.
    Arguments:
        model (nn.Module): classifier.
        optimizer_cls: optimizer class.
        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
        criterion (func): loss criterion to use.
    """
    def __init__(self, model, optimizer_cls, optimizer_kwargs, criterion):
        super(FtWrapper, self).__init__(criterion,
                                        model,
                                        optimizer_cls,
                                        optimizer_kwargs)
        # We use the same inner optimizer throughout
        self.optimizer = self.optimizer_cls(self.model.parameters(),
                                            **self.optimizer_kwargs)
    @staticmethod
    def gen_multitask_batches(tasks, train):
        """Generates one batch iterator across all tasks.

        NOTE(review): expects `tasks` as (task_id, iterator) pairs, unlike
        BaseWrapper.run_tasks which iterates plain loaders — confirm callers.
        """
        iterator_id = 0
        all_batches = []
        for task_id, iterator in tasks:
            if train:
                iterator.dataset.train()
            else:
                iterator.dataset.eval()
            for batch in iterator:
                all_batches.append((iterator_id, task_id, batch))
            iterator_id += 1
        if train:
            # Shuffle so batches from different tasks interleave in one epoch.
            random.shuffle(all_batches)
        return all_batches
    def run_tasks(self, tasks, meta_train):
        original = None
        if not meta_train:
            # Evaluation: snapshot weights and reset batch-norm statistics;
            # weights are restored after evaluation.
            original = clone_state_dict(self.model.state_dict(keep_vars=True))
            # Non-transductive task evaluation for fair comparison
            for module in self.model.modules():
                if hasattr(module, 'reset_running_stats'):
                    module.reset_running_stats()
        # Training #
        all_batches = self.gen_multitask_batches(tasks, train=True)
        trainres = self.run_multitask(all_batches, train=True)
        # Eval #
        all_batches = self.gen_multitask_batches(tasks, train=False)
        valres = self.run_multitask(all_batches, train=False)
        # NOTE(review): AggRes in this file aggregates its argument as
        # train-only results; (train, val) pairs from zip would not match
        # that interface — confirm which AggRes variant this expects.
        results = AggRes(zip(trainres, valres))
        if not meta_train:
            self.model.load_state_dict(original)
        return results
    def _partial_meta_update(self, l, final):
        return
    def _final_meta_update(self):
        return
    def run_multitask(self, batches, train):
        """Train on task in multi-task mode
        This is equivalent to the run_task method but differs in that
        batches are assumed to be mixed from different tasks.
        """
        N = len(batches)
        if train:
            self.model.train()
        else:
            self.model.eval()
        device = next(self.model.parameters()).device
        res = {}
        for n, (iterator_id, task_id, (input, target)) in enumerate(batches):
            input = input.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            # Multi-headed model: head selected by task id.
            prediction = self.model(input, task_id)
            loss = self.criterion(prediction, target)
            if iterator_id not in res:
                res[iterator_id] = Res()
            res[iterator_id].log(loss=loss.item(),
                                 pred=prediction,
                                 target=target)
            # TRAINING #
            if not train:
                continue
            final = (n + 1) == N
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            if final:
                break
        ###
        # Order per-iterator results by iterator id before aggregating.
        res = [r[1] for r in sorted(res.items(), key=lambda r: r[0])]
        for r in res:
            r.aggregate()
        return res
import os
import time
import torch
import numpy as np
import json
import _pickle
import math
from multiprocessing import Pool
import fcntl
import random
import csv
import pynvml,time
def getAvaliableDevice(gpu=[1,2,3,4,0,5],min_mem=18000,left=False):
# def getAvaliableDevice(gpu=[6],min_mem=10000,left=False):
    """
    Block until a GPU with enough free memory is available and return its
    remapped index (see `dic` below).

    :param gpu: physical GPU indices to poll, in preference order.
        NOTE(review): mutable default argument — harmless here since it is
        only read, but fragile if a caller ever mutates it.
    :param min_mem: minimum free memory (MiB) for a GPU to qualify.
    :param left: when True, require at least two free GPUs so one is left
        unused (forced off between 23:00 and 08:00).
    :return: remapped index of the first available GPU.
    """
    # return 0
    pynvml.nvmlInit()
    # Current hour; overnight we grab any free GPU without leaving spares.
    t=int(time.strftime("%H", time.localtime()))
    if t>=23 or t <8:
        left=False # do not leave any GPUs
    #else:
        #left=True
    min_num=3  # NOTE(review): unused variable
    dic = {0: 5, 1: 0, 2: 1, 3: 2, 4: 3, 5: 4,-1: -1} # just for 120 server
    ava_gpu = -1
    # Poll every 5 seconds until at least one qualifying GPU appears.
    while ava_gpu == -1:
        avaliable=[]
        for i in gpu:
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            # handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
            utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
            # print((utilization.gpu))
            if (meminfo.free / 1024 ** 2)>min_mem:
                avaliable.append(dic[i])
            # elif i ==0 and (meminfo.free / 1024 ** 2)>16000:
            #     avaliable.append(dic[i])
            elif (meminfo.free / 1024 ** 2)>16000 and utilization.gpu<20:
                # Slightly less memory is acceptable when the GPU is idle.
                avaliable.append(dic[i])
        if len(avaliable)==0 or (left and len(avaliable)<=1):
            ava_gpu = -1
            time.sleep(5)
            continue
        ava_gpu= avaliable[0]
    return ava_gpu
# def write_shared_file(file_name,content):
# nowtime=time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
# content[0]=nowtime+" "+content[0]
# with open(file_name,'a+') as f:
# fcntl.flock(f,fcntl.LOCK_EX)
# f.writelines(content)
# fcntl.flock(f,fcntl.LOCK_UN)
def write_csv_file(file_name,content):
    """Append *content* (a dict) as one row of a shared CSV file.

    A "time" key holding the current timestamp is added to *content*
    (the caller's dict is mutated). The header row is only written when
    the file does not exist yet. An exclusive flock guards the write so
    concurrent processes do not interleave rows.
    """
    content["time"] = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    write_header = not os.path.exists(file_name)
    with open(file_name,'a+') as f:
        writer = csv.DictWriter(f, content.keys())
        fcntl.flock(f, fcntl.LOCK_EX)
        if write_header:
            writer.writeheader()
        writer.writerow(content)
        fcntl.flock(f, fcntl.LOCK_UN)
import pandas as pd
def write_excel_file(path_root,content):
    """Record a run's best test accuracy in a per-configuration Excel sheet.

    One workbook per (dataset, net, num, seed); rows are methods, columns
    are source->target task codes (e.g. 'A2D'). Requires pandas with an
    Excel engine (e.g. openpyxl) installed.
    """
    file=os.path.join(path_root, content["dataset"]+ str(content["net"])+ str(content["num"])+str(content["seed"])+ "shot.xlsx")
    if not os.path.exists(file):
        # Seed the workbook with an empty "methods" column.
        dff = pd.DataFrame(columns=["methods"])
        dff.to_excel(file)
    df=pd.read_excel(file, sheet_name='Sheet1')
    # Task code like 'A2D' from the first letters of source and target.
    task=content['s'].upper()[0]+"2"+ content['t'].upper()[0]
    row=content["method"]
    # NOTE(review): `in` on a pandas Series tests the *index labels*, not
    # the values — so this checks row positions rather than the "methods"
    # column contents; confirm this matches the intended bookkeeping.
    if row not in df["methods"]:
        df.loc[row] = 0 # add a row
    if task not in df.columns:
        df[task] = 0 # add a column
    df.loc[row,task]=content["best_test_acc"]
    df.to_excel(file)
def get_para_num(net):
    """Count all and trainable parameters of *net* (a torch module)."""
    sizes = [(p.numel(), p.requires_grad) for p in net.parameters()]
    return {'Total': sum(n for n, _ in sizes),
            'Trainable': sum(n for n, grad in sizes if grad)}
def setup_seed(seed=0):
    """Seed python, numpy and torch (CPU + all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def serialize(obj, path, in_json=False):
    """Persist *obj* to *path*: .npy for ndarrays, JSON if requested, else pickle."""
    if isinstance(obj, np.ndarray):
        np.save(path, obj)
        return
    if in_json:
        with open(path, "w") as file:
            json.dump(obj, file, indent=2)
    else:
        with open(path, 'wb') as file:
            _pickle.dump(obj, file)
def unserialize(path):
suffix = os.path.basename(path).split(".")[-1]
if suffix == "npy":
return np.load(path)
elif suffix == "json":
with open(path, "r") as file:
return json.load(file)
else:
with open(path, 'rb') as file:
return _pickle.load(file)
def set_gpu(x):
    """Restrict CUDA to the given devices.

    x: comma-separated device id string, e.g. "0" or "0,1"; must be set
    before CUDA is initialised to take effect.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = x
    print('using gpu:', x)
def check_dir(path):
    '''
    Create directory (including any missing parents) if it does not exist.
    path: Path of directory.
    '''
    # makedirs + exist_ok replaces the original exists-then-mkdir pair:
    # it avoids the check/create race and, unlike os.mkdir, also handles
    # nested paths whose parents are missing.
    os.makedirs(path, exist_ok=True)
def uniform(size, tensor):
    """Fill tensor in place from U(-b, b) with b = 1/sqrt(size); a None
    tensor is silently ignored."""
    if tensor is None:
        return
    limit = 1.0 / math.sqrt(size)
    tensor.data.uniform_(-limit, limit)
def count_accuracy(logits, label):
    """Return the percentage (0-100) of rows whose argmax over logits
    matches the corresponding label."""
    predictions = torch.argmax(logits, dim=1).reshape(-1)
    hits = predictions.eq(label.reshape(-1)).float()
    return hits.mean().item() * 100
class Timer():
    """Wall-clock stopwatch started at construction."""

    def __init__(self):
        self.o = time.time()  # start timestamp (attribute name kept public)

    def measure(self, p=1):
        """Return elapsed time divided by p as a short string: seconds,
        whole minutes, or tenths of hours depending on magnitude."""
        elapsed = int((time.time() - self.o) / float(p))
        if elapsed >= 3600:
            return '{:.1f}h'.format(elapsed / 3600)
        if elapsed >= 60:
            return '{}m'.format(round(elapsed / 60))
        return '{}s'.format(elapsed)
import datetime
def log(log_file_path, string):
    '''
    Write one line of log into screen and file.
    log_file_path: Path of log file.
    string: String to write in log file.
    '''
    # local name 'stamp' avoids shadowing the imported time module
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(log_file_path, 'a+') as fh:
        fh.write(string + " " + stamp + '\n')
        fh.flush()
    print(string)
def store(st, writer, epoch=None):
    """Dump per-step optimization statistics to a summary writer.

    st: dict of per-step lists ("loss", "stop_gates", "scores", plus
    per-layer lists "grads", "input_gates", "forget_gates").
    writer: object exposing add_scalars(tag, scalar_dict, step), e.g. a
    tensorboard SummaryWriter. epoch is accepted for API compatibility
    but unused.
    """
    n_steps = len(st["loss"])
    # combined loss / stop-gate / score curve, one point per update step
    for s in range(n_steps):
        writer.add_scalars(
            "l_s_s",
            {"loss": st["loss"][s],
             "stop_gate": st["stop_gates"][s],
             "scores": st["scores"][s]},
            s,
        )
    # per-layer statistics: one scalar group per step for each family
    for tag in ["grads", "input_gates", "forget_gates"]:
        for s in range(n_steps):
            layer_vals = {"layer" + str(i): v for i, v in enumerate(st[tag][s])}
            writer.add_scalars(tag, layer_vals, s)
def interleave_offsets(batch, nu):
    """Split batch items into nu+1 near-equal groups and return the list of
    cumulative offsets (length nu+2, starting at 0 and ending at batch)."""
    base = batch // (nu + 1)
    groups = [base] * (nu + 1)
    # hand the remainder out to the trailing groups, one item each
    for k in range(batch - base * (nu + 1)):
        groups[-k - 1] += 1
    offsets = [0]
    for size in groups:
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets
def interleave(xy, batch):
    """Interleave a list of batches by swapping matching chunks between the
    first batch and each of the others, then re-concatenate each batch.
    (Commonly used so BatchNorm sees mixed labeled/unlabeled statistics.)"""
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    chunked = [[t[offsets[i]:offsets[i + 1]] for i in range(nu + 1)] for t in xy]
    # exchange chunk i of the first batch with chunk i of batch i
    for i in range(1, nu + 1):
        chunked[0][i], chunked[i][i] = chunked[i][i], chunked[0][i]
    return [torch.cat(chunks, dim=0) for chunks in chunked]
def linear_rampup(current, rampup_length=0):
    """Linear ramp from 0 to 1 over rampup_length steps; a zero length
    disables the ramp (always 1.0)."""
    if rampup_length == 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0))
def exp_rampup(current, rampup_length=0):
    """Exponential ramp exp(2*current/rampup_length) - 0.99 clipped to
    [0, 1]; a zero length disables the ramp (always 1.0)."""
    if rampup_length == 0:
        return 1.0
    raw = np.exp(2 * current / rampup_length) - 0.99
    return float(np.clip(raw, 0.0, 1.0))
# import torch
import torch.nn as nn
class Bn_Controller:
    """Snapshot and restore BatchNorm running statistics.

    freeze_bn and unfreeze_bn must appear in pairs: freeze_bn records the
    running stats of every (Sync)BatchNorm module, unfreeze_bn writes them
    back and clears the snapshot.
    """

    def __init__(self):
        self.backup = {}  # "<module>.<stat>" -> cloned tensor

    def freeze_bn(self, model):
        """Record running_mean/var/num_batches_tracked of all BN layers."""
        assert self.backup == {}
        for name, module in model.named_modules():
            if isinstance(module, (nn.SyncBatchNorm, nn.BatchNorm2d)):
                for stat in ('running_mean', 'running_var', 'num_batches_tracked'):
                    self.backup[name + '.' + stat] = getattr(module, stat).data.clone()

    def unfreeze_bn(self, model):
        """Write the recorded statistics back and drop the snapshot."""
        for name, module in model.named_modules():
            if isinstance(module, (nn.SyncBatchNorm, nn.BatchNorm2d)):
                for stat in ('running_mean', 'running_var', 'num_batches_tracked'):
                    getattr(module, stat).data = self.backup[name + '.' + stat]
        self.backup = {}
class EMA:
    """
    Exponential moving average of model parameters.
    Implementation from https://fyubang.com/2019/06/01/ema/
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}  # name -> averaged tensor
        self.backup = {}  # name -> live tensor saved by apply_shadow

    def load(self, ema_model):
        """Seed the shadow weights from an already-averaged model."""
        for name, param in ema_model.named_parameters():
            self.shadow[name] = param.data.clone()

    def register(self):
        """Snapshot the current trainable parameters as the initial average."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend current weights into the shadow: s = d*s + (1-d)*w."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow
            blended = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
            self.shadow[name] = blended.clone()

    def apply_shadow(self):
        """Swap the averaged weights into the model, saving the live ones."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        """Undo apply_shadow, putting the live weights back."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.backup
            param.data = self.backup[name]
        self.backup = {}
class Get_Scalar:
    """Wrap a constant so it can be queried like a schedule: both
    get_value(iter) and calling the object return the same fixed value."""

    def __init__(self, value):
        self.value = value

    def get_value(self, iter):
        # the iteration index is ignored: the schedule is constant
        return self.value

    def __call__(self, iter):
        return self.value
import torch.nn.functional as F
def ce_loss(logits, targets, use_hard_labels=True, reduction='none'):
    """
    wrapper for cross entropy loss in pytorch.
    Args
        logits: logit values, shape=[Batch size, # of classes]
        targets: integer or vector, shape=[Batch size] or [Batch size, # of classes]
        use_hard_labels: If True, targets have [Batch size] shape with int values. If False, the target is vector (default True)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    if use_hard_labels:
        # log_softmax + nll_loss is the numerically stable form of
        # cross_entropy on raw logits
        return F.nll_loss(log_probs, targets, reduction=reduction)
    # soft labels: per-sample cross entropy against a target distribution
    assert logits.shape == targets.shape
    return torch.sum(-targets * log_probs, dim=1)
from torch.optim.lr_scheduler import LambdaLR
def get_cosine_schedule_with_warmup(optimizer,
                                    num_training_steps,
                                    num_cycles=7. / 16.,
                                    num_warmup_steps=0,
                                    last_epoch=-1):
    '''
    Build a cosine-decay LambdaLR scheduler with optional linear warmup.
    Set num_warmup_steps (int) > 0 to enable the warmup phase.
    '''
    def _lr_lambda(current_step):
        '''Multiplicative LR factor for the given step.'''
        # linear warmup: 0 -> 1 over num_warmup_steps
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # cosine decay over the remaining steps, clamped at zero
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(0.0, math.cos(math.pi * num_cycles * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _lr_lambda, last_epoch)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment