mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 06:48:22 -05:00)
less lines for torch
@@ -1,24 +1,20 @@
 import torch
-from typing import ClassVar
+from typing import ClassVar, Final
 from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
 from tinygrad.helpers import getenv
 
 specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
   UnaryOps.RELU: lambda x: x.relu(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
-  MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg])
+  MovementOps.STRIDED: lambda x, arg: x.contiguous().as_strided([y[0] for y in arg], [y[1] for y in arg]),
+  ProcessingOps.CONV: lambda x,w,C: C.px == C.px_ and C.py == C.py_ and torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
 })
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 class TorchBuffer(GenericExecAST):
   fxn_for_op : ClassVar = specialized_fxn_for_op
+  SUPPORTS_SIMPLE_PADDING : Final = True
 
   @staticmethod
   def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
   def toCPU(x): return x.buf.cpu().numpy()
-
-  SUPPORTS_SIMPLE_PADDING = True
-  def processing_op(x,op,w,C):
-    assert op == ProcessingOps.CONV, f"{op} isn't supported"
-    assert C.px == C.px_ and C.py == C.py_, "asymmetric padding in conv is not supported"
-    return TorchBuffer(torch.conv2d(x.buf, w.buf, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)))
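A note on the (lambda d: d.update(base_fxn_for_op) or d)({...}) construction: dict.update mutates in place and returns None, so the "or d" falls through to the merged dict, building the whole op table in one expression. Because update overwrites existing keys, entries from base_fxn_for_op win on any collision. A minimal sketch of the idiom, with stand-in tables (base_ops and the op names here are made up for illustration):

base_ops = {"NEG": lambda x: -x, "RELU": lambda x: max(x, 0.0)}
# d.update(...) returns None, so "or d" evaluates to the now-merged dict.
specialized = (lambda d: d.update(base_ops) or d)({
  "RELU": lambda x: 0.0 if x < 0 else x,  # overwritten by base_ops["RELU"] on merge
  "EXP2": lambda x: 2.0 ** x,             # kept: not present in base_ops
})
print(sorted(specialized))     # ['EXP2', 'NEG', 'RELU']
print(specialized["EXP2"](3))  # 8.0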
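The MovementOps.PAD lambda reverses the per-dimension padding before flattening because torch.nn.functional.pad expects a flat [before, after, before, after, ...] list ordered from the last dimension back to the first. A quick check of that ordering (the shapes are illustrative):

import torch

x = torch.zeros(2, 3)
padding = ((1, 1), (2, 2))  # per-dim (before, after), outermost dim first
# F.pad wants [dim1_before, dim1_after, dim0_before, dim0_after],
# hence the [::-1] before flattening.
flat = [item for sublist in padding[::-1] for item in sublist]
print(flat)                                    # [2, 2, 1, 1]
print(torch.nn.functional.pad(x, flat).shape)  # torch.Size([4, 7])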
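MovementOps.STRIDED receives arg as one (size, stride) pair per output dimension and splits it into the two lists as_strided takes; the .contiguous() call matters because as_strided addresses the underlying storage directly. For example:

import torch

x = torch.arange(6.)    # contiguous storage: [0., 1., 2., 3., 4., 5.]
arg = ((2, 3), (3, 1))  # (size, stride) per output dimension
y = x.contiguous().as_strided([s for s, _ in arg], [st for _, st in arg])
print(y)  # tensor([[0., 1., 2.], [3., 4., 5.]]) -- a 2x3 view over the same storage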
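The new ProcessingOps.CONV entry folds the removed processing_op method's two assertions into the lambda via short-circuit "and": torch.conv2d only runs when the padding is symmetric (px == px_ and py == py_); on asymmetric padding the chain stops and the lambda yields False instead of raising. A standalone sketch of that pattern (Conv2dArgs is a hypothetical stand-in for the C argument object):

import torch
from dataclasses import dataclass

@dataclass
class Conv2dArgs:  # hypothetical stand-in for tinygrad's conv-args "C"
  px: int; px_: int; py: int; py_: int  # left/right and top/bottom padding
  sx: int = 1; sy: int = 1; dx: int = 1; dy: int = 1; groups: int = 1

conv = lambda x,w,C: C.px == C.px_ and C.py == C.py_ and \
  torch.conv2d(x, w, stride=(C.sy, C.sx), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))

x, w = torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3)
print(conv(x, w, Conv2dArgs(1, 1, 1, 1)).shape)  # torch.Size([1, 4, 8, 8])
print(conv(x, w, Conv2dArgs(1, 0, 1, 1)))        # False: asymmetric padding short-circuits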
||||