mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
mps option in torch (note: it's broken)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import os, torch
|
||||
from tinygrad.llops.ops_cpu import CPUBuffer # type: ignore
|
||||
from tinygrad.ops import ProcessingOps, GenericExecAST
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if int(os.getenv("MPS", "0")) else "cpu"))
|
||||
class TorchBuffer(torch.Tensor, GenericExecAST):
|
||||
def pad(x, padding): return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
|
||||
def strided(x, arg): return x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
|
||||
|
||||
Reference in New Issue
Block a user