Basic support for Generative Adversarial Networks
 

GAN stands for Generative Adversarial Nets and was invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator will try to make new images similar to the ones in a dataset, while the critic will try to classify real images from the ones the generator produces. The generator returns images, the critic a single number (usually a probability, 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:

  1. Freeze the generator and train the critic for one step by:
    • getting one batch of true images (let's call that real)
    • generating one batch of fake images (let's call that fake)
    • having the critic evaluate each batch and compute a loss function from that; the important part is that it rewards positively the detection of real images and penalizes the fake ones
    • updating the weights of the critic with the gradients of this loss
  2. Freeze the critic and train the generator for one step by:
    • generating one batch of fake images
    • evaluating the critic on it
    • returning a loss that rewards positively the critic thinking these are real images
    • updating the weights of the generator with the gradients of this loss
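
The two steps above can be sketched as a small training function. Everything here is a stand-in: `gan_step`, the loss functions, and the optimizer objects are illustrative assumptions, not fastai's API (fastai implements this alternation in `GANTrainer` together with a switcher callback):

```python
# Minimal sketch of one round of alternating GAN training, assuming
# `generator`, `critic`, the loss functions and the optimizers are
# supplied by the caller.
def gan_step(real, noise, generator, critic,
             gen_loss_func, crit_loss_func, gen_opt, crit_opt):
    # Step 1: freeze the generator, train the critic.
    fake = generator(noise)                        # one batch of fakes
    crit_loss = crit_loss_func(critic(real), critic(fake))
    crit_opt.step(crit_loss)                       # update the critic's weights
    # Step 2: freeze the critic, train the generator.
    fake = generator(noise)
    gen_loss = gen_loss_func(critic(fake))         # reward fooling the critic
    gen_opt.step(gen_loss)                         # update the generator's weights
    return crit_loss, gen_loss
```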

Wrapping the modules

class GANModule [source]

GANModule(generator=None, critic=None, gen_mode=False) :: Module

Wrapper around a generator and a critic to create a GAN.

This is just a shell to contain the two models. When called, it will either delegate the input to the generator or the critic depending on the value of gen_mode.

GANModule.switch [source]

GANModule.switch(gen_mode=None)

Put the module in generator mode if gen_mode, in critic mode otherwise.

By default (leaving gen_mode to None), this will put the module in the other mode (critic mode if it was in generator mode and vice versa).

basic_critic [source]

basic_critic(in_size, n_channels, n_features=64, n_extra_layers=0, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

A basic critic for images n_channels x in_size x in_size.

class AddChannels [source]

AddChannels(n_dim) :: Module

Add n_dim channels at the end of the input.
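
The description can be illustrated with a tensor-shape sketch; `add_channels` below is a guess at the behaviour from the docstring, not fastai's implementation:

```python
import torch

# Sketch of the behaviour described for AddChannels (an assumption based
# on the description above): append n_dim trailing unit dimensions.
def add_channels(x, n_dim):
    return x.view(*x.shape, *([1] * n_dim))
```

This is the kind of reshape that lets a generator turn a flat batch of noise vectors of shape (bs, in_sz) into a (bs, in_sz, 1, 1) tensor that transposed convolutions can consume.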

basic_generator [source]

basic_generator(out_size, n_channels, in_sz=100, n_features=64, n_extra_layers=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

A basic generator from in_sz to images n_channels x out_size x out_size.

critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])

DenseResBlock [source]

DenseResBlock(nf, norm_type=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros')

Resnet block of nf features. conv_kwargs are passed to conv_layer.

gan_critic [source]

gan_critic(n_channels=3, nf=128, n_blocks=3, p=0.15)

Critic to train a GAN.

class GANLoss [source]

GANLoss(gen_loss_func, crit_loss_func, gan_model) :: GANModule

Wrapper around crit_loss_func and gen_loss_func.

In generator mode, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate if the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature

def gen_loss_func(fake_pred, output, target):

to be able to combine the output of the critic on output (which is the first argument fake_pred) with output and target (if you want to mix the GAN loss with other losses, for instance).

In critic mode, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. This loss function has the following signature

def crit_loss_func(real_pred, fake_pred):

where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise using the generator.
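
As an illustration, here is what a matching pair of loss functions could look like with binary cross-entropy as the underlying loss. This is a sketch in the spirit of these signatures, not the library's exact code:

```python
import torch
import torch.nn.functional as F

def gen_loss_func(fake_pred, output, target):
    # The generator succeeds when the critic scores its fakes as real
    # (label 1); `output` and `target` are ignored here, but could be
    # mixed in (e.g. an L1 pixel loss between output and target).
    return F.binary_cross_entropy_with_logits(
        fake_pred, torch.ones_like(fake_pred))

def crit_loss_func(real_pred, fake_pred):
    # The critic should score real images as 1 and fake images as 0.
    return (F.binary_cross_entropy_with_logits(
                real_pred, torch.ones_like(real_pred)) +
            F.binary_cross_entropy_with_logits(
                fake_pred, torch.zeros_like(fake_pred)))
```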

class AdaptiveLoss [source]

AdaptiveLoss(crit) :: Module

Expand the target to match the output size before applying crit.
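
A sketch of what that expansion might look like; the exact broadcasting used by AdaptiveLoss is an assumption here:

```python
import torch

def adaptive_loss(crit, output, target):
    # Broadcast a per-item target to the shape of the critic output,
    # then apply the wrapped loss function `crit`.
    return crit(output, target[:, None].expand_as(output).float())
```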

accuracy_thresh_expand [source]

accuracy_thresh_expand(y_pred, y_true, thresh=0.5, sigmoid=True)

Compute accuracy after expanding y_true to the size of y_pred.

Callbacks for GAN training

set_freeze_model [source]

set_freeze_model(m, rg)

class GANTrainer [source]

GANTrainer(switch_eval=False, clip=None, beta=0.98, gen_first=False, show_img=True) :: Callback

Handles GAN training.

class FixedGANSwitcher [source]

FixedGANSwitcher(n_crit=1, n_gen=1) :: Callback

Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
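
The schedule it produces can be sketched as a small generator function (`fixed_switch_schedule` is a hypothetical helper for illustration only, not part of fastai):

```python
def fixed_switch_schedule(n_iters, n_crit=1, n_gen=1):
    """Yield 'critic' or 'generator' for each iteration: n_crit critic
    iterations, then n_gen generator iterations, repeated."""
    gen_mode, count = False, 0
    for _ in range(n_iters):
        yield 'generator' if gen_mode else 'critic'
        count += 1
        if count == (n_gen if gen_mode else n_crit):
            # Switch sides and restart the counter.
            gen_mode, count = not gen_mode, 0
```

With n_crit=5, n_gen=1 (a common WGAN setting) this yields five critic steps for every generator step.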

class AdaptiveGANSwitcher [source]

AdaptiveGANSwitcher(gen_thresh=None, critic_thresh=None) :: Callback

Switcher that goes back to generator/critic when the loss goes below gen_thresh/crit_thresh.

class GANDiscriminativeLR [source]

GANDiscriminativeLR(mult_lr=5.0) :: Callback

Callback that handles multiplying the learning rate by mult_lr for the critic.

GAN data

class InvisibleTensor [source]

InvisibleTensor(x, **kwargs) :: TensorBase

generate_noise [source]

generate_noise(fn, size=100)

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner

gan_loss_from_func [source]

gan_loss_from_func(loss_gen, loss_crit, weights_gen=None)

Define loss functions for a GAN from loss_gen and loss_crit.

class GANLearner [source]

GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False, switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner

A Learner suitable for GANs.

from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
learn.show_results(max_n=9, ds_idx=0)