REQUIRES_SIMPLE_REDUCE

George Hotz
2022-07-19 11:42:14 -07:00
parent acbeaf0ba9
commit 46e7dfade1
3 changed files with 14 additions and 9 deletions

@@ -112,6 +112,7 @@ class GPUBuffer:
   def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
   def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), start=GPUBuffer.start_for_op[op])
+  #REQUIRES_SIMPLE_REDUCE = True
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
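
The commented-out line is the hook this commit introduces: a device buffer class may set REQUIRES_SIMPLE_REDUCE = True to ask the lazy layer to only ever hand it two-dimensional (num, 1) reduces. A minimal sketch of a backend opting in — SimpleBuffer and its method body are hypothetical, the class attribute (read below via getattr(x.dbuffer, "REQUIRES_SIMPLE_REDUCE", False)) is the only actual contract:

  # Hypothetical backend, not part of this commit: only the flag is the real interface.
  class SimpleBuffer:
    REQUIRES_SIMPLE_REDUCE = True  # ask LazyBuffer.reduce_op for (num, red) -> (num, 1) reduces only
    def reduce_op(x, op, new_shape):
      # with the flag set, every reduce arrives pre-shaped by the lazy layer
      assert len(new_shape) == 2 and new_shape[1] == 1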

@@ -245,8 +245,16 @@ class LazyBuffer:
   def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, x, y)
   def contiguous_op(x:LazyBuffer) -> LazyBuffer: return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
   # TODO: permute to put all the reduce axis at the end
   def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape))) if x.shape != tuple(new_shape) else x
+    if x.shape == tuple(new_shape): return x
+    if getattr(x.dbuffer, "REQUIRES_SIMPLE_REDUCE", False) and (len(new_shape) != 2 or new_shape[1] != 1):
+      num, red = prod([s for s,n in zip(x.shape, new_shape) if n != 1]), prod([s for s,n in zip(x.shape, new_shape) if n == 1])
+      x = x.movement_op(MovementOps.PERMUTE, [i for i,n in enumerate(new_shape) if n != 1] + [i for i,n in enumerate(new_shape) if n == 1])
+      x = x.movement_op(MovementOps.RESHAPE, (num, red))
+      return x.reduce_op(op, (num, 1)).movement_op(MovementOps.RESHAPE, new_shape)
+    else:
+      return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))
   # syntactic sugar around PAD and SHRINK
   # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
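
When the backend sets the flag, reduce_op rewrites every reduce into a canonical two-step form: PERMUTE the surviving axes in front of the reduced ones, RESHAPE to (num, red), do the simple (num, 1) reduce, and RESHAPE back to new_shape. The recursive x.reduce_op(op, (num, 1)) call cannot loop, since its target is already (num, 1) and the len(new_shape) != 2 or new_shape[1] != 1 guard sends it down the plain else branch. A sketch of the same axis bookkeeping in numpy (numpy stands in for the movement ops; the shapes are made up for illustration):

  import numpy as np

  # Mirror of the PERMUTE -> RESHAPE -> (num, 1) reduce -> RESHAPE sequence above.
  x, new_shape = np.arange(60, dtype=np.float32).reshape(4, 3, 5), (4, 1, 5)
  kept = [i for i, n in enumerate(new_shape) if n != 1]  # surviving axes: [0, 2]
  gone = [i for i, n in enumerate(new_shape) if n == 1]  # reduced axes: [1]
  num, red = int(np.prod([x.shape[i] for i in kept])), int(np.prod([x.shape[i] for i in gone]))  # 20, 3
  y = x.transpose(kept + gone).reshape(num, red)  # MovementOps.PERMUTE, then RESHAPE to (num, red)
  y = y.sum(axis=1, keepdims=True)                # the "simple" (num, 1) reduce the backend asked for
  y = y.reshape(new_shape)                        # RESHAPE back to the requested output shape
  assert np.allclose(y, x.sum(axis=1, keepdims=True))  # agrees with reducing axis 1 directly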

@@ -23,10 +23,8 @@ class Tensor:
     if isinstance(data, np.ndarray):
       if data.shape == tuple(): data = data.reshape((1,))
       self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
-    elif isinstance(data, LazyBuffer):
-      self.lazydata = data
-    else:
-      raise Exception(f"can't create Tensor from {data}")
+    elif isinstance(data, LazyBuffer): self.lazydata = data
+    else: raise Exception(f"can't create Tensor from {data}")
     # tensors have gradients, buffers do not
     self.grad : Optional[Tensor] = None
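
The collapsed branches behave exactly as before; assuming the tinygrad API at this commit, a quick tour of the three constructor paths:

  import numpy as np
  from tinygrad.tensor import Tensor

  t = Tensor(np.array(3.0))  # 0-d ndarray: reshaped to (1,) and cast to float32
  assert t.shape == (1,)
  u = Tensor(t.lazydata)     # a LazyBuffer is wrapped as-is
  try: Tensor("hello")       # anything else is rejected
  except Exception as e: print(e)  # -> can't create Tensor from hello
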
@@ -55,8 +53,7 @@ class Tensor:
     return self
   def assign(self, x):
-    if not isinstance(x, Tensor):
-      x = Tensor(x)
+    if not isinstance(x, Tensor): x = Tensor(x)
     assert self.shape == x.shape
     self.lazydata = x.lazydata
     return x
@@ -127,8 +124,7 @@ class Tensor:
                   for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
       for t, g in zip(t0._ctx.parents, grads):
         if g is not None and t.requires_grad:
-          assert g.shape == t.shape, \
-            f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
+          assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
   # ***** non first class ops (hlops) *****