欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > pytorch dataloader修改shape的问题

pytorch dataloader修改shape的问题

2024/10/24 19:28:56 来源:https://blog.csdn.net/gychixxx/article/details/141278260  浏览:    关键词:pytorch dataloader修改shape的问题

不管是list类型的数据还是dict类型的数据,输入给dataloader都会把我的数据shape给转置了。

比如我的数据应该是[batch_size, seq_len]尺寸的,dataloader吐出来的却是[seq_len, batch_size]。如果还是tensor类型还行,但是它会变成List[Tensor]类型。如果只是两维也还行,但是如果是[batch_size, seq_len, nvars]呢,很难把shape变回来。

之前其实没有遇到过这个问题。通过查看源码,加上和之前的代码对比,终于找到了原因:

    elif isinstance(elem, collections.abc.Sequence):# check to make sure that the elements in batch have consistent sizeit = iter(batch)elem_size = len(next(it))if not all(len(elem) == elem_size for elem in it):raise RuntimeError('each element in list of batch should be of equal size')transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

dataloader默认的collate_fn是这样处理的。也就是说只要是collections.abs.Sequence类型的,都会被转置。

那么解决办法就明显了,只要数据不是这个类型的就行。

所以在准备dataset的时候要把数据转成numpy或者tensor格式,这样就不会走转置这个路径了。

如果用的是huggingface的datasets,可以用dataset.set_format(type='numpy')来设置数据格式,且它只对数字类型的起作用,不会影响文本类型,可以放心使用。

实在搞不懂collate默认转置的意义是什么。开发者的解释在这里default_collate: if elem in batch is list, the whole batch would be transpose? · Issue #50272 · pytorch/pytorch (github.com)

感兴趣的可以看下。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com