mps option in torch (note: it's broken)

This commit is contained in:
George Hotz
2023-01-25 10:10:31 -08:00
parent 66da3bc3c0
commit 0d594ccc51

View File

@@ -1,8 +1,8 @@
import torch
import os, torch
from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
from tinygrad.ops import ProcessingOps, GenericExecAST
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if int(os.getenv("MPS", "0")) else "cpu"))
class TorchBuffer(torch.Tensor, GenericExecAST):
def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])