device_guard¶
注解
该 API 仅支持静态图模式。
一个用于指定 OP 运行设备的上下文管理器。
参数¶
device (str|None) – 指定上下文中使用的设备。它可以是
cpu
、gpu
、gpu:x
,其中x
是 GPU 的编号。当它被设置为cpu
或者gpu
时,创建在该上下文中的 OP 将被运行在 CPUPlace 或者 CUDAPlace 上。若设置为gpu
,同时程序运行在单卡模式下,设备的索引将与执行器的设备索引保持一致,默认值:None,在该上下文中的 OP 将被自动地分配设备。
代码示例¶
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> paddle.enable_static()
>>> support_gpu = paddle.is_compiled_with_cuda()
>>> place = paddle.CPUPlace()
>>> if support_gpu:
... place = paddle.CUDAPlace(0)
>>> # if GPU is supported, the three OPs below will be automatically assigned to CUDAPlace(0)
>>> data1 = paddle.full(shape=[1, 3, 8, 8], fill_value=0.5, dtype='float32')
>>> data2 = paddle.full(shape=[1, 3, 64], fill_value=0.5, dtype='float32')
>>> shape = paddle.shape(data2)
>>> with paddle.static.device_guard("cpu"):
... # Ops created here will be placed on CPUPlace
... shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4])
>>> with paddle.static.device_guard('gpu'):
... # if GPU is supported, OPs created here will be placed on CUDAPlace(0), otherwise on CPUPlace
... out = paddle.reshape(data1, shape=shape)
>>> exe = paddle.static.Executor(place)
>>> exe.run(paddle.static.default_startup_program())
>>> result = exe.run(fetch_list=[out])