在conda里使用pytorch通过AMD Radeon Pro 5500M 8 GB 加速
- 创建新的conda环境,指定Python 3.10:
conda create -n pytorch_amd python=3.10
conda activate pytorch_amd
- 安装支持MPS(Metal Performance Shaders)后端的PyTorch最新稳定版本:
conda install pytorch torchvision torchaudio -c pytorch
- 验证安装并检查MPS可用性:
import torch print(torch.__version__) print(torch.backends.mps.is_available())
- 在代码中使用MPS设备:
device = torch.device("mps")
print(device)
float_points = torch.zeros(10, 2).float()
points_gpu = float_points.to(device='mps')
print(points_gpu)
points_gpu = points_gpu + 4
print(points_gpu) # points_gpu
points_gpu.zero_()
print(points_gpu)
a = torch.tensor(list(range(9))).to(device='mps')
print(a)
- 如果遇到内存问题,可以尝试设置环境变量:
export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.7
注意事项:
- 使用Python 3.10可能会提高与PyTorch的兼容性。
- MPS后端主要针对Apple Silicon优化,对AMD GPU的支持可能不完善。
- 某些复杂操作可能不支持或性能较差。
- 性能提升可能不如NVIDIA CUDA显著。
- 对于大型模型,可能还是需要依赖CPU。
如果遇到问题:
- 检查是否有最新的PyTorch版本更新。
- 尝试使用较小的批量大小或模型。
- 考虑回退到CPU模式。
总的来说,虽然技术上可行,但在这种配置下使用PyTorch进行GPU加速可能会遇到一些挑战。如果遇到持续问题,可以考虑使用CPU模式或寻求PyTorch社区的支持。