The following code calls the CUDA runtime API functions directly:
// allocate host memory:
float *mem_host1 = (float *)malloc(sizeof(float) * SIZE[0] * SIZE[1] * SIZE[2]);
float *mem_host2 = (float *)malloc(sizeof(float) * SIZE[0] * SIZE[1] * SIZE[2]);
if((mem_host1 == 0) || (mem_host2 == 0)) {
  cerr << "out of memory\n";
  exit(1);
}

// init host memory:
init(mem_host1);

// allocate device memory:
// For linear memory allocated with cudaMalloc3D, extent.width is given in
// BYTES (not elements); height and depth are row/slice counts. The pitch
// returned in mem_device.pitch is likewise in bytes and may include padding.
cudaExtent extent;
extent.width = SIZE[0] * sizeof(float);
extent.height = SIZE[1];
extent.depth = SIZE[2];
cudaPitchedPtr mem_device;
CUDA_CHECK(cudaMalloc3D(&mem_device, extent));

// copy from host memory to device memory:
// The host buffer is densely packed, so its pitch is exactly one row of
// floats; the device side must use the (possibly padded) pitch returned
// by cudaMalloc3D. The copy extent width is in bytes as well.
cudaMemcpy3DParms p = { 0 };
p.srcPtr.ptr = mem_host1;
p.srcPtr.pitch = SIZE[0] * sizeof(float);
p.srcPtr.xsize = SIZE[0];
p.srcPtr.ysize = SIZE[1];
p.dstPtr.ptr = mem_device.ptr;
p.dstPtr.pitch = mem_device.pitch;
p.dstPtr.xsize = SIZE[0];
p.dstPtr.ysize = SIZE[1];
p.extent.width = SIZE[0] * sizeof(float);
p.extent.height = SIZE[1];
p.extent.depth = SIZE[2];
p.kind = cudaMemcpyHostToDevice;
CUDA_CHECK(cudaMemcpy3D(&p));

// copy from device memory to host memory:
// Reuse the same parameter block; only the endpoints and direction change
// (xsize/ysize and the extent are identical for both copies).
p.srcPtr.ptr = mem_device.ptr;
p.srcPtr.pitch = mem_device.pitch;
p.dstPtr.ptr = mem_host2;
p.dstPtr.pitch = SIZE[0] * sizeof(float);
p.kind = cudaMemcpyDeviceToHost;
CUDA_CHECK(cudaMemcpy3D(&p));

// verify host memory:
verify(mem_host2);

// free memory:
CUDA_CHECK(cudaFree(mem_device.ptr));
free(mem_host2);
free(mem_host1);
The equivalent code written with CUDA templates is given below. Each operation is performed by a single statement, and resource deallocation is implicit: no explicit free calls are needed.
try {
  // allocate host memory:
  Cuda::HostMemoryHeap3D<float> mem_host1(SIZE[0], SIZE[1], SIZE[2]);
  Cuda::HostMemoryHeap3D<float> mem_host2(SIZE[0], SIZE[1], SIZE[2]);

  // init host memory:
  init(mem_host1.getBuffer());

  // allocate device memory:
  Cuda::DeviceMemoryPitched3D<float> mem_device(SIZE[0], SIZE[1], SIZE[2]);

  // copy from host memory to device memory (destination first, source second):
  copy(mem_device, mem_host1);

  // copy from device memory to host memory:
  copy(mem_host2, mem_device);

  // verify host memory:
  verify(mem_host2.getBuffer());
}
catch(const exception &e) {
  // Terminate the message with a newline, matching the raw runtime
  // version's "out of memory\n" diagnostic.
  cerr << e.what() << '\n';
}