diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index e5ba6f5849..85315a8a3b 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -9,13 +9,13 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad import dtypes from tinygrad.shape.view import View -class InvalidLazyOpException(Exception): pass +class InvalidASTException(Exception): pass def helper_test_verify_ast(*stores:UOp) -> Kernel: sink = UOp(UOps.SINK, None, stores) if DEBUG >= 3: for op in stores: print(op) try: verify_ast(sink) - except AssertionError as e: raise InvalidLazyOpException(e.args) + except AssertionError as e: raise InvalidASTException(e.args) k = Kernel(sink) k.linearize() if DEBUG >= 6: print_uops(k.uops) @@ -42,14 +42,14 @@ class TestVerifyAST(unittest.TestCase): a = UOp(UOps.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop())) b = UOp(UOps.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop())) st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b) - with self.assertRaises(InvalidLazyOpException): helper_test_verify_ast(st0, st1) + with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1) def test_no_implicit_broadcasting(self): bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop())) b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,))) st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b)) - with self.assertRaises(InvalidLazyOpException): helper_test_verify_ast(st) + with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) def test_shrink_ok(self): bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] @@ -63,14 +63,14 @@ class TestVerifyAST(unittest.TestCase): a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r) - with self.assertRaisesRegex(InvalidLazyOpException, "implicit expand"): helper_test_verify_ast(st) + with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) def test_reduce_add_store(self): bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a) - with self.assertRaisesRegex(InvalidLazyOpException, "implicit expand"): helper_test_verify_ast(st) + with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 70690245de..b9b74dcdc2 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -52,7 +52,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: - """recursively create a lazyop""" + """recursively create a UOp""" if buf is not buf.base: st, buf = buf.st+st, buf.base if (buf, st) in cache: return cache[(buf, st)] assert buf.op is not None, "base must be a base itself" @@ -90,13 +90,13 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (buf.op, rinfo[1]))) # elementwise ops pass shapetracker - in_ops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs) + in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs) if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}: assert buf in outputs, f"{buf.op} must be writable" - return in_ops[0] - if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_ops)) - if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_ops)) - return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_ops, buf.op)) + return in_uops[0] + if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops)) + if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) + return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]: permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis @@ -180,8 +180,8 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> if vv: var_vals.update(vv) ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i) ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src))) - return LBScheduleItem(UOp(UOps.SINK, None, tuple(ast)), outs, list(inputs), var_vals, - dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])) + sink = UOp(UOps.SINK, None, tuple(ast)) + return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])) # *** DAG creation: decide which LazyBuffers should realize *** diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 6d039a3b8e..928b975138 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -31,7 +31,7 @@ class LazyBuffer: self._base: Optional[LazyBuffer] = None if base is None: # properties on base - self.op, self.arg, self.srcs = op, arg, srcs # this is a LazyOp, except the src is LazyBuffers and not LazyOps + self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps assert self.op is not MetaOps.ASSIGN or srcs[1].base.realized is not None, "assign target must be realized" if self.op is MetaOps.VIEW: