mirror of https://github.com/tinygrad/tinygrad.git

No ctx in llops (#345)

* remove ctx from gpu ops
* ctx for the others
* this is okay
* mlops are not static. fix lazy
* cl is property, _processing_op is class method
* kernel_name
* contiguous_op
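The refactor is a calling-convention change: the low-level ops no longer receive a ctx/device module; each backend buffer exposes unary_op/binary_op/reduce_op/movement_op/processing_op as methods on itself, with the buffer first and the op enum second. A minimal sketch of what that looks like at the call site (toy names, not the real tinygrad classes):

# Hedged sketch of the calling-convention change, not the actual tinygrad code.
# Before: ret = ctx.op.binary_op(BinaryOps.ADD, x, y)
# After:  ret = x.binary_op(BinaryOps.ADD, y)
from enum import Enum

class BinaryOps(Enum): ADD = 0; MUL = 1

class ToyBuffer:
  def __init__(self, data): self.data = data
  # op dispatch lives on the buffer class, no ctx argument needed
  def binary_op(self, op, y):
    if op == BinaryOps.ADD: return ToyBuffer(self.data + y.data)
    if op == BinaryOps.MUL: return ToyBuffer(self.data * y.data)
    raise NotImplementedError(op)

print(ToyBuffer(2).binary_op(BinaryOps.ADD, ToyBuffer(3)).data)  # 5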
@@ -2,6 +2,9 @@ from __future__ import annotations
from typing import Union, NamedTuple, List, Any, Tuple, Dict
from tinygrad.shapetracker import ShapeTracker
import functools, operator
from tinygrad.helpers import prod
import sys
sys.setrecursionlimit(10000)

from tinygrad.ops import ReduceOps, BinaryOps, MovementOps, ProcessingOps, log_op, DEBUG, GRAPH
from enum import Enum
@@ -70,6 +73,36 @@ class LazyBuffer:
def toCPU(self):
return self.realize().toCPU()

def unary_op(x, op): return elementwise_op(op, (x,))
def binary_op(x, op, y:LazyBuffer): return elementwise_op(op, (x,y))
def contiguous_op(x): return x if x.st.contiguous else LazyBuffer(x.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (x,)))

@functools.lru_cache(maxsize=None)
def movement_op(x, op:MovementOps, arg) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and x.optype == BinaryOps:
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
return elementwise_op(y.op, tuple(replace_with_movement_op(z) for z in y.src))
return replace_with_movement_op(x.op)

# if a MovementOp is applied to a MovementOp, merge them and use one buffer
ret = LazyBuffer(x.st, MovementOps, LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
ret.shape = ret.st.movement_op(op, arg).shape # update the shape after we modify the ShapeTracker

if REMOVE_MOVEMENT_NOPS and x.optype == MovementOps and x.realized is None and ret.st.contiguous:
root = get_root(x.op)
if ret.st.shape == root.shape:
return root

return ret

def reduce_op(x, op, new_shape:Tuple[int]):
return LazyBuffer(new_shape, ReduceOps, LazyOp(op, (x,), new_shape))

def processing_op(x, op, w:LazyBuffer, C):
return LazyBuffer(C.out_shape, ProcessingOps, LazyOp(op, (x.contiguous_op(), w.contiguous_op()), C))

def ast_op(op: Op, srcs_code: List[str]) -> str:
code = gops.code_for_op[op]
if len(srcs_code) >= 1: code = code.replace("A", srcs_code[0])
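movement_op above merges chained movement ops into one LazyBuffer (MERGE_MOVEMENT_OPS) and, when the combined ShapeTracker comes out contiguous with the root buffer's shape, returns the root unchanged (REMOVE_MOVEMENT_NOPS). A toy sketch of that no-op collapse, with a hypothetical ToyLazy standing in for LazyBuffer and the contiguity check omitted:

class ToyLazy:
  def __init__(self, shape, parent=None): self.shape, self.parent = shape, parent
  def reshape(self, new_shape):
    ret = ToyLazy(new_shape, parent=self)
    root = self
    while root.parent is not None: root = root.parent   # like get_root(x.op)
    if ret.shape == root.shape: return root              # REMOVE_MOVEMENT_NOPS, minus the contiguity check
    return ret

a = ToyLazy((4, 4))
print(a.reshape((16,)).reshape((4, 4)) is a)  # True: the round trip collapses back to the original buffer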
@@ -117,37 +150,6 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer:

return LazyBuffer(out_shape, BinaryOps, LazyOp(op, srcs))

def unary_op(op, x): return elementwise_op(op, (x,))
def binary_op(op, x, y): return elementwise_op(op, (x,y))

@functools.lru_cache(maxsize=None)
def movement_op(op:MovementOps, x:LazyBuffer, arg):
if SHUFFLE_MOVEMENT_OPS and x.optype == BinaryOps:
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
if isinstance(y, LazyBuffer): return movement_op(op, y, arg)
return elementwise_op(y.op, tuple(replace_with_movement_op(z) for z in y.src))
return replace_with_movement_op(x.op)

# if a MovementOp is applied to a MovementOp, merge them and use one buffer
ret = LazyBuffer(x.st, MovementOps, LazyOp(op, (x.op if MERGE_MOVEMENT_OPS and x.optype == MovementOps and x.realized is None else x,), arg))
ret.shape = ret.st.movement_op(op, arg).shape # update the shape after we modify the ShapeTracker

if REMOVE_MOVEMENT_NOPS and x.optype == MovementOps and x.realized is None and ret.st.contiguous:
root = get_root(x.op)
if ret.st.shape == root.shape:
return root

return ret

def reduce_op(op, x, new_shape):
return LazyBuffer(new_shape, ReduceOps, LazyOp(op, (x,), new_shape))

def processing_op(op, x, w, C):
if not x.st.contiguous: x = LazyBuffer(x.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (x,)))
if not w.st.contiguous: w = LazyBuffer(w.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (w,)))
return LazyBuffer(C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))

# these functions determines the backing buffer
import tinygrad.llops.ops_gpu as gops
@@ -181,7 +183,7 @@ def _realize_binary_op(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBu
real_dict[s] = f"arg_{len(real_srcs)}"
real_srcs.append((f"arg_{len(real_srcs)}", s.realize()))
code = ast(self.op, real_dict)
return gops._processing_op(self.shape, real_srcs, code, arg), [x[1] for x in real_srcs]
return gops.GPUBuffer(self.shape)._processing_op(real_srcs, code, arg), [x[1] for x in real_srcs]

def _realize(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBuffer]]:
if self.optype == LoadOps and self.op.op == LoadOps.FROMCPU:
@@ -189,10 +191,10 @@ def _realize(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBuffer]]:
return gops.GPUBuffer.fromCPU(self.op.arg), []
elif self.optype == LoadOps and self.op.op == LoadOps.CONTIGUOUS:
real_src = self.op.src[0].realize()
return gops.contiguous(real_src), [real_src]
return real_src.contiguous(), [real_src]
elif self.optype == ReduceOps:
real_src = self.op.src[0].realize()
return gops.reduce_op(self.op.op, real_src, self.op.arg), [real_src]
return real_src.reduce_op(self.op.op, self.op.arg), [real_src]
elif self.optype == MovementOps:
real_src = get_root(self.op).realize()
return gops.GPUBuffer(self.st, real_src), [real_src]
@@ -5,7 +5,7 @@ def prod(x): return math.prod(x)
def reduce_shape(shape, axis):
return [1 if i in axis else shape[i] for i in range(len(shape))]

conv_args = namedtuple('conv_args', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'px', 'dy', 'dx', 'out_shape'])
ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'ys', 'xs', 'bs', 'cout', 'py', 'px', 'dy', 'dx', 'out_shape'])
def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
# TODO: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout
cout,cin,H,W = w_shape
@@ -19,4 +19,4 @@ def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1):
ox = (ix + 2*px - dx * (W-1) - 1)//xs + 1
if cin*groups != cin_: raise Exception(f"Input Tensor shape {x_shape} does not match the shape of the weights {w_shape}. ({cin*groups} vs. {cin_})")
assert cout % groups == 0
return conv_args(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, px, dy, dx, (bs, cout, oy, ox))
return ConvArgs(H, W, groups, cout//groups, cin, oy, ox, iy, ix, ys, xs, bs, cout, py, px, dy, dx, (bs, cout, oy, ox))
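get_conv_args packs the convolution geometry into the (now capitalized) ConvArgs namedtuple; the output spatial size comes from the dilated-convolution formula visible above, ox = (ix + 2*px - dx*(W-1) - 1)//xs + 1. A quick standalone check of that arithmetic, independent of tinygrad:

# Worked example of the output-size formula used above (illustrative only).
def conv_out(i, k, stride=1, pad=0, dil=1):
  return (i + 2*pad - dil*(k - 1) - 1)//stride + 1

print(conv_out(32, 3, stride=1, pad=1))   # 32: 3x3, stride 1, pad 1 keeps the spatial size
print(conv_out(32, 3, stride=2, pad=1))   # 16: stride 2 halves it (rounding down)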
@@ -17,60 +17,60 @@ class CPUBuffer(np.ndarray):
def fromCPU(x): return x
def toCPU(x): return x

def unary_op(op, x):
if op == UnaryOps.NOOP: return x[:]
elif op == UnaryOps.RELU: return x.relu()
elif op == UnaryOps.EXP: return x.exp()
elif op == UnaryOps.LOG: return x.log()
elif op == UnaryOps.NEG: return -x
elif op == UnaryOps.SIGN: return x.sign()
else: raise Exception(f"{op} isn't supported")

def binary_op(op, x, y):
if op == BinaryOps.ADD: return x+y
elif op == BinaryOps.SUB: return x-y
elif op == BinaryOps.MUL: return x*y
elif op == BinaryOps.DIV: return y/x
elif op == BinaryOps.POW: return x**y
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")

def reduce_op(op, inp, new_shape):
if inp.shape == new_shape: # this is just a copy, regardless of the reduce op
return inp[:]
else:
if new_shape == (1,): # full reduce
axis = tuple(range(len(inp.shape)))
else:
assert len(inp.shape) == len(new_shape)
axis = tuple([i for i,(a,b) in enumerate(zip(inp.shape, new_shape)) if a != b])
if op == ReduceOps.SUM: return inp.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: return inp.amax(axis, keepdims=True)
def unary_op(x, op):
if op == UnaryOps.NOOP: return x[:]
elif op == UnaryOps.RELU: return x.relu()
elif op == UnaryOps.EXP: return x.exp()
elif op == UnaryOps.LOG: return x.log()
elif op == UnaryOps.NEG: return -x
elif op == UnaryOps.SIGN: return x.sign()
else: raise Exception(f"{op} isn't supported")

def movement_op(op, x, arg=None):
if op == MovementOps.RESHAPE: return x.reshape(arg)
elif op == MovementOps.PERMUTE: return x.permute(arg)
elif op == MovementOps.FLIP: return x.flip(arg)
elif op == MovementOps.SLICE:
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])]
elif op == MovementOps.EXPAND: return x.expand(arg)
else: raise Exception(f"{op} isn't supported")
def binary_op(x, op, y):
if op == BinaryOps.ADD: return x+y
elif op == BinaryOps.SUB: return x-y
elif op == BinaryOps.MUL: return x*y
elif op == BinaryOps.DIV: return y/x
elif op == BinaryOps.POW: return x**y
elif op == BinaryOps.CMPEQ: return 1.0*(x==y)
else: raise Exception(f"{op} isn't supported")

def processing_op(op,x,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
writeable=False,
)
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
for g in range(C.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
tmp[:,g] = np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
def reduce_op(x, op, new_shape):
if x.shape == new_shape: # this is just a copy, regardless of the reduce op
return x[:]
else:
if new_shape == (1,): # full reduce
axis = tuple(range(len(x.shape)))
else:
assert len(x.shape) == len(new_shape)
axis = tuple([i for i,(a,b) in enumerate(zip(x.shape, new_shape)) if a != b])
if op == ReduceOps.SUM: return x.sum(axis, keepdims=True)
elif op == ReduceOps.MAX: return x.amax(axis, keepdims=True)
else: raise Exception(f"{op} isn't supported")

def movement_op(x, op, arg=None):
if op == MovementOps.RESHAPE: return x.reshape(arg)
elif op == MovementOps.PERMUTE: return x.permute(arg)
elif op == MovementOps.FLIP: return x.flip(arg)
elif op == MovementOps.SLICE:
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
slicee = [(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)]
return x.custompad(padding)[tuple([slice(x[0], x[1], None) for x in slicee])]
elif op == MovementOps.EXPAND: return x.expand(arg)
else: raise Exception(f"{op} isn't supported")

def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
if C.px > 0 or C.py > 0: x = np.pad(x, [(0,0), (0,0), (C.py, C.py), (C.px, C.px)])
gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
tx = np.lib.stride_tricks.as_strided(gx,
shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
writeable=False,
)
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
tmp = np.empty((C.bs,C.groups,C.oy,C.ox,C.rcout),dtype=x.dtype)
for g in range(C.groups):
#ijYXyx,kjyx -> iYXk ->ikYX
tmp[:,g] = np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
return np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox).view(CPUBuffer)
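The CPU reduce_op infers which axes to reduce by comparing the input shape against the requested output shape: every dimension that differs must collapse to 1. The same inference can be checked with plain numpy:

import numpy as np

def infer_axes(in_shape, new_shape):
  # every dimension that differs is a reduce axis (it must become 1)
  return tuple(i for i, (a, b) in enumerate(zip(in_shape, new_shape)) if a != b)

x = np.arange(2*3*4, dtype=np.float32).reshape(2, 3, 4)
axes = infer_axes(x.shape, (2, 1, 4))
print(axes)                                   # (1,)
print(x.sum(axis=axes, keepdims=True).shape)  # (2, 1, 4)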
@@ -1,9 +1,9 @@
from __future__ import annotations
import functools
import numpy as np
import pyopencl as cl
from typing import List, Tuple
from tinygrad.helpers import prod
from tinygrad.llops.ops_cpu import unary_op
from typing import List, Tuple, Optional
from tinygrad.helpers import prod, ConvArgs
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape

@@ -19,12 +19,33 @@ def require_init_gpu():
cl_ctx = cl.Context(devices=devices)
cl_queue = cl.CommandQueue(cl_ctx) # this is an in-order command queue

@functools.lru_cache(maxsize=None)
class CLProgram:
def __init__(self, name, prg, options=tuple(), argdtypes=None):
self.name = name
self.built = cl.Program(cl_ctx, prg).build(options=options)
self.clprg = self.built.__getattr__(name)
if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes)
def __call__(self, *args):
#print(f"running {self.name} with {args[0]} count {len(args)-2}")
self.clprg(cl_queue, *args)

code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "(-(A))", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(B/A)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
}

class GPUBuffer:
def __init__(self, shape, hostbuf=None):
def __init__(self, shape, hostbuf:Optional[GPUBuffer]=None):
require_init_gpu()
self.st = ShapeTracker(shape)
self.shape = self.st.shape
self.cl = hostbuf.cl if hostbuf is not None else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*prod(self.shape))
self._buf = hostbuf._buf if hostbuf is not None else None

@property
def cl(self):
if self._buf is None: self._buf = cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*prod(self.shape))
return self._buf

def __repr__(self):
return f"<GPUBuffer with shape {self.shape!r}>"
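Two patterns from the GPU hunks above: functools.lru_cache applied to the CLProgram class memoizes instances by constructor arguments, so identical kernel source only gets built once; and GPUBuffer's cl handle becomes a property that allocates lazily on first access ("cl is property" in the commit message). A minimal sketch of both, using made-up FakeProgram/FakeBuffer classes instead of the real OpenCL objects:

import functools

@functools.lru_cache(maxsize=None)          # same constructor args -> same cached instance
class FakeProgram:
  def __init__(self, name, src):
    print(f"building {name}")               # the expensive build runs once per unique (name, src)
    self.name, self.src = name, src

a = FakeProgram("add", "...kernel source...")
b = FakeProgram("add", "...kernel source...")
print(a is b)   # True: the second call hits the cache, nothing is rebuilt

class FakeBuffer:
  def __init__(self, size): self.size, self._buf = size, None
  @property
  def buf(self):                             # allocate backing memory only when first touched
    if self._buf is None: self._buf = bytearray(self.size)
    return self._buf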
@@ -38,143 +59,128 @@ class GPUBuffer:

def toCPU(self):
data = np.empty(self.shape, dtype=np.float32)
cl.enqueue_copy(cl_queue, data, contiguous(self).cl, is_blocking=True)
cl.enqueue_copy(cl_queue, data, self.contiguous_op().cl, is_blocking=True)
return data

@functools.lru_cache(maxsize=None)
class CLProgram:
def __init__(self, name, prg, options=tuple(), argdtypes=None):
self.name = name
self.built = cl.Program(cl_ctx, prg).build(options=options)
self.clprg = self.built.__getattr__(name)
if argdtypes is not None: self.clprg.set_scalar_arg_dtypes(argdtypes)
def __call__(self, *args):
#print(f"running {self.name} with {args[0]} count {len(args)-2}")
self.clprg(cl_queue, *args)
def contiguous_view(x, name:str) -> str:
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"

def contiguous_view(name:str, x:GPUBuffer) -> str:
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"
def unary_op(x, op:UnaryOps): return type(x)(x.shape)._processing_op([("A", x)], code_for_op[op])
def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], code_for_op[op])
def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)

def _processing_op(out_shape: Tuple[int], bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C=None):
ret = GPUBuffer(out_shape)
options = []
def movement_op(x, op:MovementOps, arg) -> GPUBuffer:
ret = GPUBuffer(x.st, x)
ret.shape = ret.st.movement_op(op, arg).shape
return ret

if C is not None:
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "ys", "xs", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
if C.px == 0 and C.py == 0: options.append("-DALLVALID")
if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE")
global_size = [C.bs*C.cout, C.oy, C.ox]
assert bufs[0][0] == "input" and bufs[1][0] == "weight"
ewbufs = bufs[2:] # input and weight are consumed by the convs
else:
ints, params = '', []
options.append("-DNOCONV")
global_size = [prod(ret.shape), 1, 1]
ewbufs = bufs
def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)

elementwise_prefix = '\n'.join([contiguous_view(name, buf) for name, buf in ewbufs])+ \
"inline float _ewop("+','.join(["int gid", "float acc"]+[f"__global const float *{name}_g" for name, _ in ewbufs])+") {"+ \
'\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in ewbufs])+ \
f"return {code}; }}"
def reduce_op(x, op:ReduceOps, new_shape:Tuple[int]):
ret = GPUBuffer(new_shape)
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")

conv_params = ["__global float* restrict output"] + \
[f"__global const float *{name}_g" for name, _ in bufs] + \
[x[0] for x in params]
conv_prg = CLProgram("conv", elementwise_prefix+"""
__kernel void conv("""+','.join(conv_params)+""") {
float acc = 0.0;
int gid = get_global_id(0);
"""+ints+"""
# reverse operation of expand, this validates inputs
st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, x.shape)
# this takes a ret index to an inp index, indexing 0 on the reduced strides
view = View(ret.shape, strides_for_shape(x.shape))

#ifndef NOCONV
int B = gid/(groups*rcout); // range 0-bs
int g = (gid/rcout)%groups;
int c = gid % rcout;
# generate loops with combined adjacent reduce axis
acc = 1
loop_start, loop_end = [], []
for shp,stride in st.views[-1].shape_strides[::-1]:
if stride == 0:
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{")
loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};")
acc *= shp

#ifdef ONEBYONE
int Y = 0;
int X = 0;
#else
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox
gid = gid*oy*ox + Y*ox + X;
#endif
# TODO: support multistage reduces
CLProgram("reduce", x.contiguous_view('A')+"""
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
float out = """+start+""";\n"""+ \
'\n'.join(loop_start[::-1])+"""
float a = get_A(a_g, idx);
"""+code+""";\n"""+ \
'\n'.join(loop_end)+"""
res_g[gid] = out;
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret

int IY = Y*ys;
int IX = X*xs;
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None) -> GPUBuffer:
options = []
if C is not None:
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "ys", "xs", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
if C.px == 0 and C.py == 0: options.append("-DALLVALID")
if C.oy == 1 and C.ox == 1: options.append("-DONEBYONE")
global_size = [C.bs*C.cout, C.oy, C.ox]
assert bufs[0][0] == "input" and bufs[1][0] == "weight"
ewbufs = bufs[2:] # input and weight are consumed by the convs
kernel_name = "conv"
else:
ints, params = '', []
options.append("-DNOCONV")
global_size = [prod(ret.shape), 1, 1]
ewbufs = bufs
kernel_name = "elementwise"

for (int ci = 0; ci < cin; ci++) {
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + IY - py;
int idx_x = x*dx + IX - px;
#ifdef ALLVALID
acc += input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#else
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#endif
} }
}
#endif
elementwise_prefix = '\n'.join([buf.contiguous_view(name) for name, buf in ewbufs])+ \
"inline float _ewop("+','.join(["int gid", "float acc"]+[f"__global const float *{name}_g" for name, _ in ewbufs])+") {"+ \
'\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in ewbufs])+ \
f"return {code}; }}"

output[gid] = _ewop("""+','.join(["gid", "acc"]+[f"{name}_g" for name, _ in ewbufs])+""");
}""", options=tuple(options), argdtypes=tuple([None]*(1+len(bufs)) + [np.int32]*len(params)))
conv_prg(global_size, None, ret.cl, *[buf.cl for _, buf in bufs], *[x[1] for x in params])
return ret
conv_params = ["__global float* restrict output"] + \
[f"__global const float *{name}_g" for name, _ in bufs] + \
[x[0] for x in params]
conv_prg = CLProgram(kernel_name, elementwise_prefix+f"__kernel void {kernel_name}("+','.join(conv_params)+""") {
float acc = 0.0;
int gid = get_global_id(0);
"""+ints+"""

#ifndef NOCONV
int B = gid/(groups*rcout); // range 0-bs
int g = (gid/rcout)%groups;
int c = gid % rcout;

#ifdef ONEBYONE
int Y = 0;
int X = 0;
#else
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox
gid = gid*oy*ox + Y*ox + X;
#endif

int IY = Y*ys;
int IX = X*xs;

for (int ci = 0; ci < cin; ci++) {
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + IY - py;
int idx_x = x*dx + IX - px;
#ifdef ALLVALID
acc += input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#else
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
acc += valid * input_g[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
weight_g[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#endif
} }
}
#endif

output[gid] = _ewop("""+','.join(["gid", "acc"]+[f"{name}_g" for name, _ in ewbufs])+""");
}""", options=tuple(options), argdtypes=tuple([None]*(1+len(bufs)) + [np.int32]*len(params)))
conv_prg(global_size, None, ret.cl, *[buf.cl for _, buf in bufs], *[x[1] for x in params])
return ret

# gpu ops

code_for_op = {
UnaryOps.NOOP: "(A)", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "(-(A))", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(B/A)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
}

def unary_op(op, x): return _processing_op(x.shape, [("A", x)], code_for_op[op])
def binary_op(op, x, y): return _processing_op(x.shape, [("A", x), ("B", y)], code_for_op[op])
def contiguous(x:GPUBuffer): return x if x.st.contiguous else unary_op(UnaryOps.NOOP, x)

def movement_op(op, x, arg):
ret = GPUBuffer(x.st, x)
ret.shape = ret.st.movement_op(op, arg).shape
return ret

def processing_op(op, x, w, C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return _processing_op(C.out_shape, [("input", contiguous(x)), ("weight", contiguous(w))], "acc", C)

def reduce_op(op, x, new_shape):
ret = GPUBuffer(new_shape)
if op == ReduceOps.SUM: code, start = "out += a", "0.0"
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")

# reverse operation of expand, this validates inputs
st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, x.shape)
# this takes a ret index to an inp index, indexing 0 on the reduced strides
view = View(ret.shape, strides_for_shape(x.shape))

# generate loops with combined adjacent reduce axis
acc = 1
loop_start, loop_end = [], []
for shp,stride in st.views[-1].shape_strides[::-1]:
if stride == 0:
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{")
loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};")
acc *= shp

# TODO: support multistage reduces
CLProgram("reduce", contiguous_view('A', x)+"""
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
float out = """+start+""";\n"""+ \
'\n'.join(loop_start[::-1])+"""
float a = get_A(a_g, idx);
"""+code+""";\n"""+ \
'\n'.join(loop_end)+"""
res_g[gid] = out;
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret
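The elementwise path above generates one OpenCL kernel per op by pasting the A/B templates from code_for_op into an _ewop helper, with one get_<name>() accessor per source produced by contiguous_view. A rough sketch of just that string assembly (pure Python, constant getters standing in for the real ShapeTracker index expressions):

# Illustrative string assembly only; the real code also folds in ShapeTracker indexing and valid masks.
code_for_op = {"ADD": "(A+B)", "MUL": "(A*B)", "RELU": "max(A, (float)0.)"}

def ewop_source(code, names=("A", "B")):
  getters = "\n".join(
    f"inline float get_{n}(__global const float *x, int gid) {{ return x[gid]; }}" for n in names)
  args = ", ".join(["int gid", "float acc"] + [f"__global const float *{n}_g" for n in names])
  loads = "\n  ".join(f"float {n} = get_{n}({n}_g, gid);" for n in names)
  return f"{getters}\ninline float _ewop({args}) {{\n  {loads}\n  return {code};\n}}"

print(ewop_source(code_for_op["ADD"]))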
@@ -1,5 +1,7 @@
import torch
import numpy as np
from tinygrad.llops.ops_cpu import CPUBuffer
from tinygrad.ops import ProcessingOps

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TorchBuffer(torch.Tensor):
@@ -17,14 +19,8 @@ class TorchBuffer(torch.Tensor):
def getdtype(self):
return np.float32

# ************* unary+binary+reduce+movement ops *************
unary_op, binary_op, reduce_op, movement_op = CPUBuffer.unary_op, CPUBuffer.binary_op, CPUBuffer.reduce_op, CPUBuffer.movement_op

from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op

# ************* processing ops *************

from tinygrad.ops import ProcessingOps

def processing_op(op,x,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
def processing_op(x,op,w,C):
assert op == ProcessingOps.CONV, f"{op} isn't supported"
return torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
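The torch backend maps ProcessingOps.CONV directly onto torch.conv2d, with stride/dilation/padding/groups pulled from ConvArgs. A standalone check of the same call with illustrative shapes:

import torch

x = torch.randn(1, 8, 16, 16)   # (bs, cin*groups, iy, ix)
w = torch.randn(4, 8, 3, 3)     # (cout, cin, H, W)
out = torch.conv2d(x, w, stride=(1, 1), dilation=(1, 1), padding=(1, 1), groups=1)
print(out.shape)                # torch.Size([1, 4, 16, 16])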
@@ -6,12 +6,10 @@ from tinygrad.tensor import Function
# ************* unary ops *************

class _UnaryOp(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return ctx.unary_op(ctx.fop, input)

@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return ctx.binary_op(ctx.bop, input, grad_output)
@@ -19,7 +17,6 @@ class _UnaryOp(Function):
class ReLU(_UnaryOp):
fop = UnaryOps.RELU

@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
ret = ctx.unary_op(UnaryOps.SIGN, input)
@@ -31,7 +28,6 @@ class Log(_UnaryOp):
bop = BinaryOps.DIV

class Exp(_UnaryOp):
@staticmethod
def forward(ctx, input):
ret = ctx.unary_op(UnaryOps.EXP, input)
ctx.save_for_backward(ret) # we save the output here, not the input
@@ -42,24 +38,20 @@ class Exp(_UnaryOp):
# ************* reduce ops *************

class Sum(Function):
@staticmethod
def forward(ctx, input, axis=None):
ctx.save_for_backward(input.shape)
return ctx.reduce_op(ReduceOps.SUM, input, reduce_shape(input.shape, axis))

@staticmethod
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
return ctx.movement_op(MovementOps.EXPAND, grad_output, shape_input)

class Max(Function):
@staticmethod
def forward(ctx, input, axis=None):
ret = ctx.reduce_op(ReduceOps.MAX, input, reduce_shape(input.shape, axis))
ctx.save_for_backward(input, ret)
return ret

@staticmethod
def backward(ctx, grad_output):
input, ret = ctx.saved_tensors

@@ -77,32 +69,26 @@ class Max(Function):
# ************* binary ops *************

class Add(Function):
@staticmethod
def forward(ctx, x, y):
return ctx.binary_op(BinaryOps.ADD, x, y)

@staticmethod
def backward(ctx, grad_output):
return grad_output if ctx.needs_input_grad[0] else None, \
grad_output if ctx.needs_input_grad[1] else None

class Sub(Function):
@staticmethod
def forward(ctx, x, y):
return ctx.binary_op(BinaryOps.SUB, x, y)

@staticmethod
def backward(ctx, grad_output):
return grad_output if ctx.needs_input_grad[0] else None, \
ctx.unary_op(UnaryOps.NEG, grad_output) if ctx.needs_input_grad[1] else None

class Mul(Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return ctx.binary_op(BinaryOps.MUL, x, y)

@staticmethod
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
grad_x = ctx.binary_op(BinaryOps.MUL, y, grad_output) if ctx.needs_input_grad[0] else None
@@ -110,13 +96,11 @@ class Mul(Function):
return grad_x, grad_y

class Pow(Function):
@staticmethod
def forward(ctx, x, y):
ret = ctx.binary_op(BinaryOps.POW, x, y)
ctx.save_for_backward(x, y, ret)
return ret

@staticmethod
def backward(ctx, grad_output):
x,y,powxy = ctx.saved_tensors
grad_x, grad_y = None, None
@@ -133,58 +117,48 @@ class Pow(Function):

# NOTE: this is sum in reverse
class Expand(Function):
@staticmethod
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return ctx.movement_op(MovementOps.EXPAND, x, shape)

@staticmethod
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)

class Flip(Function):
@staticmethod
def forward(ctx, x, axis):
ctx.save_for_backward(axis)
return ctx.movement_op(MovementOps.FLIP, x, axis)

@staticmethod
def backward(ctx, grad_output):
axis, = ctx.saved_tensors
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)

class Reshape(Function):
@staticmethod
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
return ctx.movement_op(MovementOps.RESHAPE, x, shape)

@staticmethod
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return ctx.movement_op(MovementOps.RESHAPE, grad_output, in_shape)

class Permute(Function):
@staticmethod
def forward(ctx, x, order=(1,0)):
ctx.save_for_backward(order)
return ctx.movement_op(MovementOps.PERMUTE, x, order)

@staticmethod
def backward(ctx, grad_output):
order, = ctx.saved_tensors
norder = np.argsort(order).tolist()
return ctx.movement_op(MovementOps.PERMUTE, grad_output, norder)

class Slice(Function):
@staticmethod
def forward(ctx, x, arg=None):
ctx.save_for_backward(x.shape, arg)
return ctx.movement_op(MovementOps.SLICE, x, arg)

@staticmethod
def backward(ctx, grad_output):
shape, arg = ctx.saved_tensors
narg = [(0-p[0], grad_output.shape[i]+(shape[i]-p[1])) for i,p in enumerate(arg)]
@@ -193,13 +167,11 @@ class Slice(Function):
# ************* processing ops *************

class Conv2D(Function):
@staticmethod
def forward(ctx, x, w, stride=1, groups=1, dilation=1, padding=0):
C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation, padding=padding)
ctx.save_for_backward(x,w,C)
return ctx.processing_op(ProcessingOps.CONV, x, w, C)

@staticmethod
def backward(ctx, grad_output):
x, w, C = ctx.saved_tensors
dx, dw = None, None
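The mlops above all follow the same shape: forward saves what backward will need via ctx.save_for_backward and routes every tensor operation through the ctx-level helpers (ctx.binary_op, ctx.movement_op, ...), which dispatch to the buffer methods shown in the next hunk. A toy Mul-style function in that style, with a made-up ToyCtx and plain floats standing in for buffers:

class ToyCtx:
  def __init__(self): self.saved_tensors = []
  def save_for_backward(self, *x): self.saved_tensors.extend(x)
  def binary_mul(self, a, b): return a * b   # stand-in for ctx.binary_op(BinaryOps.MUL, a, b)

class ToyMul:
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return ctx.binary_mul(x, y)
  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    return ctx.binary_mul(y, grad_output), ctx.binary_mul(x, grad_output)

ctx = ToyCtx()
print(ToyMul.forward(ctx, 3.0, 4.0))   # 12.0
print(ToyMul.backward(ctx, 1.0))       # (4.0, 3.0)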
@@ -51,14 +51,14 @@ def log_op(optype, op, ret, inp):

class Ops:
def unary_op(ctx, op:UnaryOps, x):
ret = ctx.op.unary_op(op, x)
ret = x.unary_op(op)
if 'LAZY' not in ctx.device: log_op(UnaryOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret

def reduce_op(ctx, op:ReduceOps, x, new_shape):
ret = ctx.op.reduce_op(op, x, tuple(new_shape))
ret = x.reduce_op(op, tuple(new_shape))
if 'LAZY' not in ctx.device: log_op(ReduceOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
assert ret.shape == tuple(new_shape)
@@ -66,14 +66,14 @@ class Ops:

def binary_op(ctx, op:BinaryOps, x, y):
assert x.shape == y.shape
ret = ctx.op.binary_op(op, x, y)
ret = x.binary_op(op, y)
if 'LAZY' not in ctx.device: log_op(BinaryOps, op, ret, [x, y])
assert isinstance(ret, ctx.buffer)
assert ret.shape == x.shape
return ret

def movement_op(ctx, op:MovementOps, x, arg):
ret = ctx.op.movement_op(op, x, tuple(arg))
ret = x.movement_op(op, tuple(arg))
if 'LAZY' not in ctx.device: log_op(MovementOps, op, ret, [x])
assert isinstance(ret, ctx.buffer)
# this check is slow
@@ -81,7 +81,7 @@ class Ops:
return ret

def processing_op(ctx, op:ProcessingOps, x, y, C):
ret = ctx.op.processing_op(op, x, y, C)
ret = x.processing_op(op, y, C)
if 'LAZY' not in ctx.device: log_op(ProcessingOps, op, ret, [x, y])
assert isinstance(ret, ctx.buffer)
assert ret.shape == C.out_shape

@@ -111,7 +111,7 @@ class Tensor:
if not any(x.requires_grad for x in t0._ctx.parents):
continue
assert (t0.grad is not None)
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
grads = t0._ctx.backward(t0.grad.data)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
@@ -380,7 +380,6 @@ class Function(Ops):
self.saved_tensors = []

buffer = property(lambda self: Device.buffers[self.device])
op = property(lambda self: Device.llops[self.device])

def save_for_backward(self, *x):
if self.requires_grad:
@@ -389,7 +388,7 @@ class Function(Ops):
@classmethod
def apply(cls, *x:List[Tensor], **kwargs):
ctx = cls(x[0].device, *x)
ret = Tensor(cls.forward(ctx, *[t.data for t in x], **kwargs),
ret = Tensor(ctx.forward(*[t.data for t in x], **kwargs),
device=ctx.device, requires_grad=ctx.requires_grad)
if ctx.requires_grad:
ret._ctx = ctx # used by autograd engine
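Per the "mlops are not static" bullet, Function.apply now calls forward through the constructed context (ctx.forward(...)) rather than through the class with an explicit ctx argument (cls.forward(ctx, ...)), and backward is likewise invoked on the instance (t0._ctx.backward(t0.grad.data)). A simplified sketch of that apply flow, with hypothetical Toy classes rather than the real Function/Tensor machinery:

# Assumed simplification of the new apply flow, not the actual tinygrad code.
class ToyFunction:
  def __init__(self, *parents): self.parents, self.saved = parents, []
  def save_for_backward(self, *x): self.saved.extend(x)
  @classmethod
  def apply(cls, *args):
    ctx = cls(*args)
    return ctx.forward(*args), ctx   # forward is looked up on the instance now

class ToyAdd(ToyFunction):
  def forward(self, x, y): return x + y
  def backward(self, grad): return grad, grad

out, ctx = ToyAdd.apply(2, 3)
print(out)                 # 5
print(ctx.backward(1.0))   # (1.0, 1.0)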