TensorFlow CNN 形状不匹配

     2023-03-29     238

关键词:

【中文标题】TensorFlow CNN 形状不匹配【英文标题】:Tensorflow CNN shape mismatch 【发布时间】:2020-10-05 07:40:05 【问题描述】:
def load_data(data_path, batch_size, num_workers=2):
    t_m  = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1),
             transforms.Resize((400,400)),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

            ])

    dataset = torchvision.datasets.ImageFolder(root = data_path, transform=t_m)
    # print (np.shape(dataset))
    #split
    train, test = torch.utils.data.random_split(dataset, [int( len(dataset) * 0.7 ),  len(dataset) - int( len(dataset) * 0.7 ) ])


    trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                          shuffle=True, num_workers=num_workers,drop_last = True)
    testloader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers, drop_last = False)


    return dataset,trainloader,testloader





import torch.nn as nn
model = torch.nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, padding=2),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=5, padding=2),
    nn.MaxPool2d(2, 2),
    nn.Linear ( 7 * 7 * 64, 1000),
    nn.Linear(1000, 600),
    nn.Linear(600, 200),
    nn.Linear(200, 10)


)



#Training

total_epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

for epoch in tqdm(range(total_epochs)):

  #initialize 
  batch_count = 0
  gc.collect()
  loop_loss = 0.0

  for img in (trainloader):

      input_, label_ = img

      # print (input_.shape)

      out = model(input_)
      out= nn.functional.relu(out)



      loss = criterion(out, label_)
      loss.backward()

      optimizer.zero_grad()
      optimizer.step()

      loop_loss = loop_loss + loss.item()
      batch_count = batch_count + 1

      print('batch_loss: ', str(loss.item()))

  print('Epochs completed:', epoch+1,'\n')
  print('epoch_loss = ' + loop_loss/float(batch_count))

大小不匹配,m1:[25600 x 100],m2:[3136 x 1000] 在 /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

请解释形状哪里出错了?我应该如何解决这个问题? 我是新手,所以这可能不是一个好问题,但任何细节都会有所帮助 输入图像大小调整为 400,400 并从 rgb 转换为灰度

【问题讨论】:

【参考方案1】:

您的问题出在第一个线性层。总是这样编码,以便您自己弄清楚。

class MyModel(nn.Module):
    def __init__(self, params):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(...)
        self.fc = nn.Linear(...)
    def forward(self, x):
        x = self.conv1(x)
        import pdb; pdb.set_trace()
        x = self.fc(x)
        return x

这样你可以把 pdb 放在你想要的地方,你可以使用 x.shape 命令检查形状。您的问题在于 conv 层的输出形状与您的第一个 Linear 层之间的不匹配。

【讨论】:

作为参数传递什么? @AtqaAmir 看看这个:github.com/pytorch/vision/blob/master/torchvision/models/…,参数可以是类似 num_classes 的东西。如果您以这种方式执行,在 forward 方法中,您可以放置​​ pdb 跟踪器并检查张量的形状。 @AtqaAmir,也不要忘记将问题编辑为 PyTorch 形状。

TensorFlow 标签号与轴上的形状不匹配

】TensorFlow标签号与轴上的形状不匹配【英文标题】:TensorFlowlabelnumberisamismatchwiththeshapeontheaxis【发布时间】:2020-05-1802:00:03【问题描述】:尝试运行代码实验室:https://codelabs.developers.google.com/codelabs/recognize-flowers-with-tensorflow-on-an... 查看详情

根据输入形状的计算是不是存在差异? (带有 Tensorflow 的 Python 中的 CNN)

】根据输入形状的计算是不是存在差异?(带有Tensorflow的Python中的CNN)【英文标题】:Isthereadifferenceincomputationaccordingtoinputshape?(CNNinPythonwithTensorflow)根据输入形状的计算是否存在差异?(带有Tensorflow的Python中的CNN)【发布时间... 查看详情

如何在tensorflow中向cnn输入多个序列?(代码片段)

我是TensorFlow的新手,并通过一个类似于文本分类指南here的示例我有一个单一序列的工作模型,但我想弄清楚如何让每个观察的两个不同的序列进入模型训练。我可以连接两个序列并将它们作为一个输入,但我希望每个序列都是... 查看详情

Keras:密集层和激活层之间的形状不匹配

...层之间的形状不匹配。我错过了一些明显的东西吗?使用TensorFlow后端。print(x_train.shape)print(y_trai 查看详情

字节缓冲区的大小和形状不匹配(以前没有回答)

...:2021-08-1515:13:01【问题描述】:我正在尝试创建一个使用tensorflow模型的应用程序。当inputFeature0.loadBuffer(byteBuffer)被执行时,我的应用程序崩溃了。(通过评论了解)varimg 查看详情

无法在 Tensorflow 中输入形状的值

】无法在Tensorflow中输入形状的值【英文标题】:CannotfeedvalueofshapeinTensorflow【发布时间】:2020-09-0913:23:06【问题描述】:我正在尝试使用指定的确切代码按照教程训练CNN模型,但收到错误消息。下面是我正在使用的代码。#Commonimp... 查看详情

Logits 和 Labels 必须具有相同的形状:Tensorflow

】Logits和Labels必须具有相同的形状:Tensorflow【英文标题】:LogitsandLabelsmusthavethesameshape:Tensorflow【发布时间】:2021-09-1601:30:40【问题描述】:我正在尝试使用CNN网络对猫与狗进行分类,但是尽管检查了两次,但我无法找到错误的... 查看详情

TensorFlow:ValueError:形状不兼容

】TensorFlow:ValueError:形状不兼容【英文标题】:TensorFlow:ValueError:Shapesareincompatible【发布时间】:2022-01-0304:24:28【问题描述】:我在编码器-解码器模型的数据形状方面遇到了一些困难。问题似乎与Dense层有关,但我无法弄清楚为... 查看详情

为啥我的 Keras TimeDistributed CNN + LSTM 模型期望形状不完整

】为啥我的KerasTimeDistributedCNN+LSTM模型期望形状不完整【英文标题】:WhydoesmyKerasTimeDistributedCNN+LSTMmodelexpectanincompleteshape为什么我的KerasTimeDistributedCNN+LSTM模型期望形状不完整【发布时间】:2021-10-2504:09:22【问题描述】:我正在Keras... 查看详情

使用 TensorFlow CNN 进行图像分类

】使用TensorFlowCNN进行图像分类【英文标题】:ImageclassificationwithTensorFlowCNN【发布时间】:2016-04-1903:38:39【问题描述】:我对使用CNN进行图像分类非常陌生,我遵循了谷歌CNN教程:https://www.tensorflow.org/tutorials/images/cnn这很有效,但... 查看详情

Fashion Mnist TensorFlow 数据形状不兼容

】FashionMnistTensorFlow数据形状不兼容【英文标题】:FashionMnistTensorflowDataShapeIncompatibility【发布时间】:2021-10-1912:27:18【问题描述】:我知道有类似的问题。虽然我已经检查过了,但我没有解决我的问题。我尝试在时尚Mnist数据集... 查看详情

tensorflow实战-tensorflow实现卷积神经网络cnn-第5章

第5章-TensorFlow实现卷积神经网络CNN5.1卷积神经网络简介卷积神经网络CNN最初是为了解决图像识别等问题设计的,当然现在的应用已经不限于图像和视频,也可以用于时间序列信号,比如音频信号、文本数据等。在深度学习出现之... 查看详情

Tensorflow Keras 修改 Iris 示例时形状不兼容

】TensorflowKeras修改Iris示例时形状不兼容【英文标题】:TensorflowKerasIncompatibleshapeswhenmodifyingIrisexample【发布时间】:2020-10-1609:51:45【问题描述】:我正在尝试使用我自己的数据代替虹膜数据集来实现基于https://janakiev.com/notebooks/kera... 查看详情

尝试使用 Tensorflow 理解用于 NLP 教程的 CNN

】尝试使用Tensorflow理解用于NLP教程的CNN【英文标题】:TryingtounderstandCNNsforNLPtutorialusingTensorflow【发布时间】:2017-05-3015:01:04【问题描述】:我关注thistutorial是为了了解NLP中的CNN。尽管我面前有代码,但仍有一些我不明白的事情... 查看详情

具有动态输入形状的 CNN

】具有动态输入形状的CNN【英文标题】:CNNwithadynamicinputshape【发布时间】:2019-04-1910:28:31【问题描述】:大家好!由于我正在尝试制作一个将灰度图像转换为rgb图像的全卷积神经网络,因此我想知道是否可以在不同尺寸的图像... 查看详情

Tensorflow ValueError:层“顺序”的输入0与层不兼容:预期形状=(无,20,20,3),找到形状=(无,20,3)

】TensorflowValueError:层“顺序”的输入0与层不兼容:预期形状=(无,20,20,3),找到形状=(无,20,3)【英文标题】:TensorflowValueError:Input0oflayer"sequential"isincompatiblewiththelayer:expectedshape=(None,20,20,3),foundshape=(None,20,3)【发布时... 查看详情

InvalidArgumentError:找到 2 个根错误。 Tensorflow 文本分类模型中的不兼容形状

】InvalidArgumentError:找到2个根错误。Tensorflow文本分类模型中的不兼容形状【英文标题】:InvalidArgumentError:2rooterror(s)found.IncompatibleshapesinTensorflowtext-classificationmodel【发布时间】:2020-05-0902:20:46【问题描述】:我正在尝试从以下repo... 查看详情

形状不匹配:索引数组无法与形状一起广播

】形状不匹配:索引数组无法与形状一起广播【英文标题】:shapemismatch:indexingarrayscouldnotbebroadcasttogetherwithshapes【发布时间】:2018-02-1720:26:40【问题描述】:j=np.arange(20,dtype=np.int)site=np.ones((20,200),dtype=np.int)sumkma=np.ones((100,20))[sumkma... 查看详情