mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix up contiguous (#1978)
This commit is contained in:
@@ -312,5 +312,15 @@ class TestSchedule(unittest.TestCase):
|
||||
out = bb(x)
|
||||
check_schedule(out, 4)
|
||||
|
||||
def test_contiguous_while_contiguous(self):
  # .contiguous() called directly on a freshly created (1, 64, 32, 32) tensor.
  # check_schedule is given 1 (presumably the expected schedule length -- confirm against the helper),
  # with filter_loadops=False passed through unchanged.
  fresh = Tensor.empty(1, 64, 32, 32)
  check_schedule(fresh.contiguous(), 1, filter_loadops=False)
|
||||
|
||||
def test_contiguous_while_not_contiguous(self):
  # .contiguous() called on a permuted view of a freshly created (1, 64, 32, 32) tensor.
  # check_schedule is given 2 here (presumably the expected schedule length -- confirm against the helper),
  # with filter_loadops=False passed through unchanged.
  permuted = Tensor.empty(1, 64, 32, 32).permute(0,2,3,1)
  check_schedule(permuted.contiguous(), 2, filter_loadops=False)
|
||||
|
||||
# run the test suite with verbose output when this file is executed as a script
if __name__ == '__main__': unittest.main(verbosity=2)
|
||||
|
||||
@@ -161,8 +161,9 @@ class LazyBuffer:
|
||||
if seen is None: seen = set()
|
||||
if self in seen or self.realized or self.is_unrealized_const(): return []
|
||||
seen.add(self)
|
||||
if self.optype is MovementOps: return self.base.schedule(seen)
|
||||
if self.base != self: return self.base.schedule(seen)
|
||||
|
||||
# rewrite unbased CONTIGUOUS into UnaryOps.NOOP
|
||||
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
|
||||
|
||||
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
|
||||
@@ -175,12 +176,6 @@ class LazyBuffer:
|
||||
# TODO: this belongs in the schedule in some way
|
||||
self.var_vals = dict(sorted(merge_dicts([self.var_vals] + [buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
|
||||
|
||||
# contiguous can be a copy. must do this after the image hack
|
||||
if self.op.op == LoadOps.CONTIGUOUS:
|
||||
src = cast(LazyBuffer, self.op.src[0])
|
||||
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
|
||||
op = self.op
|
||||
|
||||
# run the ast and log the op
|
||||
op, base_bufs = _replace_bufferops(op)
|
||||
return ret + [(op, self, tuple(base_bufs))]
|
||||
@@ -198,6 +193,11 @@ class LazyBuffer:
|
||||
|
||||
def contiguous(self:LazyBuffer) -> LazyBuffer:
  # Return a LazyBuffer standing for a contiguous version of this one. Three outcomes:
  #   1. self is unrealized and its op is already CONTIGUOUS -> self is returned unchanged
  #   2. self's shapetracker is contiguous, has the same size as its base's, and self is not an unrealized const
  #      -> a new CONTIGUOUS LazyBuffer is created with base=self.base
  #   3. anything else -> a CONTIGUOUS loadop is created with self as its source

  # two CONTIGUOUS in a row is one
  already_contiguous_op = not self.realized and self.op.op == LoadOps.CONTIGUOUS
  if already_contiguous_op: return self

  is_based_copy = self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const()
  if is_based_copy:
    # this will turn into nothing, it's based and a copy
    # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
    fresh_st = ShapeTracker.from_shape(tuple(self.shape))
    copy_op = LazyOp(LoadOps.CONTIGUOUS, (self,), None)
    return create_lazybuffer(self.device, fresh_st, LoadOps, copy_op, self.dtype, self.var_vals, base=self.base)

  # real contiguous, this will turn into a UnaryOps.NOOP
  return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -20,11 +20,6 @@ def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBu
|
||||
if isinstance(buffers[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
|
||||
buffers[b.arg.idx-1].dtype = dtypes.float32
|
||||
|
||||
# fix the contiguous dtype, no cast required
|
||||
for op,out,buffers in schedule:
|
||||
if op.op == LoadOps.CONTIGUOUS and out.dtype != buffers[0].dtype:
|
||||
out.dtype = buffers[0].dtype = dtypes.float32
|
||||
|
||||
# now fix up the schedule to reflect the new dtypes
|
||||
fixed_schedule = []
|
||||
for op,out,buffers in schedule:
|
||||
@@ -82,11 +77,6 @@ def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
|
||||
# *** one op LoadOps ***
|
||||
|
||||
def _realize_contiguous(buffer: LazyBuffer, src: LazyBuffer) -> None:
  # Realize a CONTIGUOUS load op by aliasing: `buffer.realized` is set to the very same object as `src.realized`.
  #   buffer: the LazyBuffer being realized (its `realized` attribute is overwritten)
  #   src:    the LazyBuffer whose realized storage is reused
  # Raises AssertionError if the two dtypes differ.
  # this is just a copy now, if it's not a copy schedule will handle it

  # validate before mutating, so a dtype mismatch can't leave `buffer` pointing at storage of the wrong dtype
  assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
  buffer.realized = src.realized
|
||||
|
||||
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
|
||||
assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
|
||||
assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
|
||||
@@ -110,7 +100,6 @@ def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
|
||||
# Table from LoadOps member to the module-level _realize_* function registered for it.
# NOTE(review): presumably looked up by the realize path to run a load op -- the call site is not visible here, confirm.
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
|
||||
LoadOps.EMPTY: _realize_empty,
|
||||
LoadOps.RAND: _realize_rand,
|
||||
LoadOps.CONTIGUOUS: _realize_contiguous,
|
||||
LoadOps.FROM: _realize_from,
|
||||
LoadOps.CUSTOM: _realize_custom,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user