REQUIRES_SIMPLE_REDUCE
@@ -112,6 +112,7 @@ class GPUBuffer:
   def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
   def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), start=GPUBuffer.start_for_op[op])
 
+  #REQUIRES_SIMPLE_REDUCE = True
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
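Note (not part of the commit): a backend opts into the simple-reduce path by defining REQUIRES_SIMPLE_REDUCE = True as a class attribute on its buffer class; here it is left commented out on GPUBuffer. LazyBuffer reads the flag with getattr and defaults to False, so backends that never define it keep the old behaviour. A minimal sketch of that lookup, using hypothetical buffer classes:

    # Sketch only; SimpleBuffer/PlainBuffer stand in for device buffer classes like GPUBuffer.
    class SimpleBuffer:
      REQUIRES_SIMPLE_REDUCE = True   # backend opts in with a plain class attribute

    class PlainBuffer:
      pass                            # no attribute, so getattr falls back to False

    print(getattr(SimpleBuffer, "REQUIRES_SIMPLE_REDUCE", False))  # True
    print(getattr(PlainBuffer, "REQUIRES_SIMPLE_REDUCE", False))   # False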
@@ -245,8 +245,16 @@ class LazyBuffer:
   def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, x, y)
   def contiguous_op(x:LazyBuffer) -> LazyBuffer: return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
 
+  # TODO: permute to put all the reduce axis at the end
   def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape))) if x.shape != tuple(new_shape) else x
+    if x.shape == tuple(new_shape): return x
+    if getattr(x.dbuffer, "REQUIRES_SIMPLE_REDUCE", False) and (len(new_shape) != 2 or new_shape[1] != 1):
+      num, red = prod([s for s,n in zip(x.shape, new_shape) if n != 1]), prod([s for s,n in zip(x.shape, new_shape) if n == 1])
+      x = x.movement_op(MovementOps.PERMUTE, [i for i,n in enumerate(new_shape) if n != 1] + [i for i,n in enumerate(new_shape) if n == 1])
+      x = x.movement_op(MovementOps.RESHAPE, (num, red))
+      return x.reduce_op(op, (num, 1)).movement_op(MovementOps.RESHAPE, new_shape)
+    else:
+      return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))
 
   # syntactic sugar around PAD and SHRINK
   # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
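Worked example (a sketch, not part of the commit): reducing a tensor of shape (3, 4, 5) to (3, 1, 5) keeps axes 0 and 2 and reduces axis 1. The rewrite permutes to [0, 2, 1], reshapes to (num, red) = (15, 4), then calls reduce_op again with target (15, 1) — which the guard (len(new_shape) != 2 or new_shape[1] != 1) lets fall through to the plain LazyOp branch, so the recursion terminates — and finally reshapes the (15, 1) result back to (3, 1, 5). The shape arithmetic in isolation:

    from math import prod

    # Hypothetical shapes, mirroring the comprehensions in LazyBuffer.reduce_op above.
    shape, new_shape = (3, 4, 5), (3, 1, 5)

    keep_axes = [i for i, n in enumerate(new_shape) if n != 1]   # [0, 2]
    red_axes  = [i for i, n in enumerate(new_shape) if n == 1]   # [1]
    num = prod(s for s, n in zip(shape, new_shape) if n != 1)    # 3*5 = 15
    red = prod(s for s, n in zip(shape, new_shape) if n == 1)    # 4

    # PERMUTE to keep_axes + red_axes, RESHAPE to (num, red), reduce to (num, 1),
    # then RESHAPE back to new_shape: the 2D reduce a REQUIRES_SIMPLE_REDUCE backend expects.
    print(keep_axes + red_axes, (num, red), (num, 1), new_shape) # [0, 2, 1] (15, 4) (15, 1) (3, 1, 5)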
@@ -23,10 +23,8 @@ class Tensor:
     if isinstance(data, np.ndarray):
       if data.shape == tuple(): data = data.reshape((1,))
       self.lazydata = LazyBuffer.fromCPU(data.astype(np.float32), device)
-    elif isinstance(data, LazyBuffer):
-      self.lazydata = data
-    else:
-      raise Exception(f"can't create Tensor from {data}")
+    elif isinstance(data, LazyBuffer): self.lazydata = data
+    else: raise Exception(f"can't create Tensor from {data}")
 
     # tensors have gradients, buffers do not
     self.grad : Optional[Tensor] = None
@@ -55,8 +53,7 @@ class Tensor:
     return self
 
   def assign(self, x):
-    if not isinstance(x, Tensor):
-      x = Tensor(x)
+    if not isinstance(x, Tensor): x = Tensor(x)
     assert self.shape == x.shape
     self.lazydata = x.lazydata
     return x
@@ -127,8 +124,7 @@ class Tensor:
         for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
       for t, g in zip(t0._ctx.parents, grads):
         if g is not None and t.requires_grad:
-          assert g.shape == t.shape, \
-            f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
+          assert g.shape == t.shape, f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
           t.grad = g if t.grad is None else (t.grad + g)
 
   # ***** non first class ops (hlops) *****