V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
V2EX 提问指南
Waihinchan
V2EX  ›  问与答

关于重现 NEURAL TRANSFER 遇到的问题

  •  
  •   Waihinchan · 2020-08-07 08:07:23 +08:00 · 1164 次点击
    这是一个创建于 1587 天前的主题,其中的信息可能已经有所发展或是发生改变。
    参照的教程是这个:

    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

    主要问的不是算法本身..是 pytorch 的一些使用的问题还是很懵
    我没有完全按照上面链接的方法写.我理解的思路就是用预训练的 vgg 网络的卷积部分去提取图像的“特征”,然后转化成 gram 矩阵之后再用 MSE_LOSS 和输入的图像进行比较.

    所以我写的时候是这样写的:

    optimizer = optim.LBFGS([content_img.requires_grad_().to(device)])

    input_styles = model(content_img)
    for input,targetin zip(input_styles,target_styles):
    loss += F.mse_loss(gram(input),gram(target))
    loss.backward()

    model 就是 5 个 VGG19 的卷积+ReLu 的部分,返回的是不同深度的特征提取张亮
    gram 是转换成 gram 矩阵的函数
    优化器优化输入图像

    但是我发现当我迭代的时候 backward 完全没有传递到给 content_img
    反而当我把 gram 这个函数去掉了之后,如下:

    loss += F.mse_loss(input,target)
    优化器就发挥作用了
    这个部分按道理应该是负责 content_loss 的部分的,也就是没有经过 Gram 矩阵直接比较

    但为什么上面加了 gram 就不起作用,下面的就可以起作用呢?

    我看案例中是直接把 loss 作为一个层放进模型中的,然后发现优化器就可以正常操作了.

    还是不太理解 pytorch 中梯度到底是怎么传递的,我找了一些测试如:

    x = torch.rand(1,2,3,4)
    y = torch.rand(1,2,3,4).requires_grad_()
    optimizer = optim.LBFGS([y])

    now = [0]
    total = 200
    while now[0] < total:
    def closure():
    optimizer.zero_grad()
    z = F.mse_loss(gram(x),gram(y))*100
    z.backward()
    now[0] += 1
    print(y)
    return z
    optimizer.step(closure)

    这种情况下是可以发挥作用的,但是当我把 loss 通过 vgg 产生当不同结果累积起来之后好像 backward(还是优化器)就没有起作用了

    虽然按照最上面链接的教程是可以跑的,但是还是想知道这么写的错误在哪里呢?


    附 gram 函数:
    def gram(input):
    (bs, ch, h, w) = input.size()
    features = input.view(bs * ch, w * h)
    Gram = torch.mm(features,features.t())
    Gram = Gram.div(bs*ch*h*w)
    return Gram


    有一些缩进没有写好...迷迷糊糊的复制粘贴进来也没有调整排版
    麻烦各位大佬凑合看一看,,研究了一整天现在半夜已经精神恍惚了..请见谅
    1 条回复    2020-08-07 19:40:49 +08:00
    Waihinchan
        1
    Waihinchan  
    OP
       2020-08-07 19:40:49 +08:00
    今天又测试了一下..这个写法没有问题..是因为损失太低了,打印的时候没有转换打印格式所以一直显示数值没有变化嗯...
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   943 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 24ms · UTC 22:58 · PVG 06:58 · LAX 14:58 · JFK 17:58
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.