快捷方式

OpenMLEnv

torchrl.envs.OpenMLEnv(*args, **kwargs)[source]

一个用于 OpenML 数据的环境接口,可在 bandit 上下文中进行使用。

Doc: https://www.openml.org/search?type=data

Scikit-learn 接口: https://scikit-learn.cn/stable/modules/generated/sklearn.datasets.fetch_openml.html

参数:
  • dataset_name (str) – 支持以下数据集: "adult_num""adult_onehot""mushroom_num""mushroom_onehot""covertype""shuttle""magic"

  • device (torch.device兼容可选) – 预期输入和输出数据的设备。默认为 "cpu"

  • batch_size (torch.Size兼容可选) – 环境的批次大小,即调用 reset() 时采样的元素数量。默认为空批次大小,即一次采样一个元素。

变量:

available_envs (List[str]) – 由此类构建的环境列表。

示例

>>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3])
>>> print(env.reset())
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

为初学者和高级开发者提供深入的教程

查看教程

资源

查找开发资源并让您的问题得到解答

查看资源