欢迎来到 TensorDict 文档!¶
TensorDict 是一个字典式类,它继承了张量(tensor)的属性,例如索引、形状操作、设备转换等。
您可以直接从 PyPI 安装 tensordict(有关安装说明,请参见下文的专用部分)。
$ pip install tensordict
TensorDict 的主要目的是通过抽象定制化操作来使代码库更易读和模块化。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
通过这种级别的抽象,可以回收一个训练循环用于高度异构的任务。训练循环的每个单独步骤(数据收集和转换、模型预测、损失计算等)都可以针对当前用例进行定制,而不会影响其他步骤。例如,上述示例可以轻松地用于分类和分割任务,以及许多其他任务。
安装¶
TensorDict 的发布与 PyTorch 同步,因此请务必使用最新版本的 PyTorch 来享受该库的最新功能(尽管核心功能保证向后兼容 pytorch>=1.13)。可以通过以下方式安装 nightly 版本:
$ pip install tensordict-nightly
或者,如果您愿意为该库做出贡献,可以通过 git clone 进行安装。
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ cd tensordict
$ pip install -e .
教程¶
基础知识¶
tensordict.nn¶
数据加载¶
目录¶
- 概述
- 在分布式设置中使用 TensorDict
- 追踪 TensorDictModule
- 保存 TensorDict 和 tensorclass 对象
- API 参考
- tensordict 包
- tensordict.nn 包
- TensorDictModuleBase
- TensorDictModule
- ProbabilisticTensorDictModule
- ProbabilisticTensorDictSequential
- TensorDictSequential
- TensorDictModuleWrapper
- CudaGraphModule
- WrapModule
- InteractionType
- set_interaction_type
- set_composite_lp_aggregate
- composite_lp_aggregate
- as_tensordict_module
- 集成
- 编译 TensorDictModules
- 分布
- 工具
- tensorclass