快捷方式

TokenizedDatasetLoader

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: int | None = None, tokenizer_class=None, tokenizer_model_name=None)[source]

加载已分词的数据集,并缓存其内存映射副本。

参数:
  • split (str) – "train" 或 "valid" 中的一个。

  • max_length (int) – 最大序列长度。

  • dataset_name (str) – 数据集的名称。

  • tokenizer_fn (callable) – 分词方法构造器,例如 torchrl.data.llm.TensorDictTokenizer。调用时,它应该返回一个 tensordict.TensorDict 实例或一个类字典结构,其中包含分词后的数据。

  • pre_tokenization_hook (callable, optional) – 在分词之前对 Dataset 调用。它应该返回一个修改后的 Dataset 对象。其预期用途是执行需要修改整个 Dataset 的任务,而不是修改单个数据点,例如根据特定条件丢弃某些数据点。分词和其他对数据的“逐元素”操作由映射到 Dataset 的 process 函数执行。

  • root_dir (path, optional) – 存储数据集的路径。默认为 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果为 True,将使用 datasets.load_from_disk()。否则,将使用 datasets.load_dataset()。默认为 False

  • valid_size (int, optional) – 验证数据集的大小(如果 split 以 "valid" 开头)将被截断到此值。默认为 2000 个项目。

  • num_workers (int, optional) – 分词过程中调用的 datasets.dataset.map() 的工作进程数。默认为 max(os.cpu_count() // 2, 1)

  • tokenizer_class (Type, optional) – 分词器类,例如 AutoTokenizer (默认)。

  • tokenizer_model_name (str, optional) – 应从中收集词汇表的模型。默认为 "gpt2"

数据集将存储在 <root_dir>/<split>/<max_length>/ 中。

示例

>>> from torchrl.data.llm import TensorDictTokenizer
>>> from torchrl.data.llm.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: datasets.Dataset | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] = None, batch_dims=1, valid_mask_key=None)[source]

将数据集转换为内存映射的 TensorDict。

如果数据集已经是 TensorDict 实例,则只需将其转换为内存映射的 TensorDict。否则,数据集预计具有一个 features 属性,该属性是一个字符串序列,指示数据集中可以找到的功能。如果没有,则必须将 features 显式传递给此函数。

参数:
  • dataset (datasets.Dataset, TensorDict等效) – 要转换为内存映射 TensorDict 的数据集。如果 featuresNone,则必须有一个 features 属性,其中包含要写入 tensordict 的键列表。

  • data_dir (Path等效) – 应将数据写入的目录。

  • prefix (NestedKey, optional) – 数据集位置的前缀。可用于区分经过不同预处理的同一数据集的多个副本。

  • features (str 的序列, optional) – 一个字符串序列,指示数据集中可以找到的功能。

  • batch_dims (int, optional) – 数据的 batch_dimensions 数量(即 TensorDict 可以索引的维度)。默认为 1。

  • valid_mask_key (NestedKey, optional) – 如果提供,将尝试收集此条目并用于过滤数据。默认为 None(即,没有过滤器键)。

返回: 一个包含数据集内存映射张量的 TensorDict。

示例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[source]

加载预处理的内存映射数据集(如果存在),否则创建它。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源