diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 3d341a5215..2e6f79658a 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -1,24 +1,20 @@
 import torch
-from typing import ClassVar
+from typing import ClassVar, Final
 from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
 from tinygrad.helpers import getenv
 
 specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
   UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
   BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
-  MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
+  MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]),
+  ProcessingOps.CONV: lambda x,w,C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
 })
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 
 class TorchBuffer(GenericExecAST):
   fxn_for_op : ClassVar = specialized_fxn_for_op
+  SUPPORTS_SIMPLE_PADDING : Final = True
   @staticmethod
   def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
   def toCPU(x): return x.buf.cpu().numpy()
-
-  SUPPORTS_SIMPLE_PADDING = True
-  def processing_op(x,op,w,C):
-    assert op == ProcessingOps.CONV, f"{op} isn't supported"
-    assert C.px == C.px_ and C.py == C.py_, "asymmetric padding in conv is not supported"
-    return TorchBuffer(torch.conv2d(x.buf, w.buf, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)))
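
A note on the new `ProcessingOps.CONV` entry: it folds the old symmetric-padding assert into the lambda by chaining the check and the conv call with `and`, so asymmetric padding short-circuits before `torch.conv2d` is reached. The sketch below illustrates only that short-circuit pattern; `ConvArgs` is a hypothetical stand-in carrying just the fields the lambda reads, not tinygrad's actual conv-args type.

```python
# Minimal sketch of the short-circuit trick used by the new ProcessingOps.CONV
# lambda. ConvArgs here is a hypothetical stand-in with only the fields the
# lambda reads; tinygrad's real conv-args object carries more.
from collections import namedtuple
import torch

ConvArgs = namedtuple("ConvArgs", ["px", "px_", "py", "py_", "sx", "sy", "dx", "dy", "groups"])

conv = lambda x, w, C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(
  x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))

x, w = torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3)
sym  = ConvArgs(px=1, px_=1, py=1, py_=1, sx=1, sy=1, dx=1, dy=1, groups=1)
asym = ConvArgs(px=1, px_=2, py=1, py_=1, sx=1, sy=1, dx=1, dy=1, groups=1)

print(conv(x, w, sym).shape)  # symmetric padding: conv2d runs -> torch.Size([1, 4, 8, 8])
print(conv(x, w, asym))       # asymmetric padding: short-circuits to False, conv2d never runs
```

Compared with the removed `processing_op`, this drops the explicit error message: an unsupported asymmetric-padding case now evaluates to `False` instead of raising, and wrapping the result into a `TorchBuffer` is presumably handled by the generic `GenericExecAST` dispatch rather than done explicitly here.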