less lines for torch

This commit is contained in:
George Hotz
2023-02-08 18:15:59 -06:00
parent 58a03eb693
commit c642f5e72b

View File

@@ -1,24 +1,20 @@
import torch
from typing import ClassVar
from typing import ClassVar, Final
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
from tinygrad.helpers import getenv
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]),
ProcessingOps.CONV: lambda x,w,C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
})
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
class TorchBuffer(GenericExecAST):
fxn_for_op : ClassVar = specialized_fxn_for_op
SUPPORTS_SIMPLE_PADDING : Final = True
@staticmethod
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
def toCPU(x): return x.buf.cpu().numpy()
SUPPORTS_SIMPLE_PADDING = True
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
assert C.px == C.px_ and C.py == C.py_, "asymmetric padding in conv is not supported"
return TorchBuffer(torch.conv2d(x.buf, w.buf, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)))