From 49abc09f7746aa10f1fc33cffc28928bed38f4f7 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 12:33:25 -0500
Subject: [PATCH 1/8] remove the reshapes in test_arange_2_reduce [pr] (#9063)

---
 test/test_arange.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/test/test_arange.py b/test/test_arange.py
index 4d0bb79dc5..a9bdc3b89b 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -66,20 +66,17 @@ class TestArange(unittest.TestCase):
     return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
 
 class TestIndexing(unittest.TestCase):
-  # update: passing after CAST_BEFORE_VIEW=1 deletion
-  # @unittest.expectedFailure
   def test_arange_2_reduce(self):
     needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
     needle[1337] = 1
     needle.realize()
     with Context(NOOPT=1, FUSE_ARANGE=1):
       GlobalCounters.reset()
-      # TODO: it should work without these reshapes
-      out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
+      out = ((Tensor.arange(1,16385)-1)*needle).sum()
       sched = out.schedule()
-      assert len(sched) == 1
+      self.assertEqual(len(sched), 1)
       run_schedule(sched)
-      assert out.item() == 1337, f"expected 1337, got {out.item()}"
+      self.assertEqual(out.item(), 1337)
 
   @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
   def test_manual_index(self):
@@ -95,7 +92,7 @@ class TestIndexing(unittest.TestCase):
     full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
     X = full.sum(axis=(2,3))
     sched = X.schedule()
-    assert len(sched) == 1
+    self.assertEqual(len(sched), 1)
     run_schedule(sched)
     assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
     np.testing.assert_allclose(real_index, X.numpy())
@@ -111,7 +108,7 @@ class TestIndexing(unittest.TestCase):
     assert X.shape == (4,256)
     sched = X.schedule()
     # TODO: enable these asserts when the scheduler can handle this
-    #assert len(sched) == 1, f"{len(sched)} != 1"
+    #self.assertEqual(len(sched), 1)
     run_schedule(sched)
     #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
     np.testing.assert_allclose(real_index, X.numpy())
@@ -126,7 +123,7 @@ class TestIndexing(unittest.TestCase):
     X = dataset[idxs]
     assert X.shape == (4,256)
     sched = X.schedule()
-    assert len(sched) == 2
+    self.assertEqual(len(sched), 2)
     run_schedule(sched)
     assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
     np.testing.assert_allclose(real_index, X.numpy())
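
PATCH 1/8 works because multiplying an index vector by a one-hot mask and reducing is a dot product that returns the position of the set element. A minimal NumPy sketch of the same arithmetic (NumPy stands in for tinygrad here; the single-fused-kernel property under FUSE_ARANGE is what the test's self.assertEqual(len(sched), 1) checks, not something this sketch can show):

    import numpy as np

    needle = np.zeros(16384, dtype=np.int64)
    needle[1337] = 1                            # one-hot marker at the index to recover
    # arange(1,16385)-1 is [0, 1, ..., 16383]; the one-hot mask zeroes every
    # position except 1337, so the sum collapses to the index itself
    out = ((np.arange(1, 16385) - 1) * needle).sum()
    assert out == 1337
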
From 947c97e6ff8105d0a72982c65ec4023565cf87e3 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 15:25:54 -0500
Subject: [PATCH 2/8] add test_sqrt to test_speed_v_torch (#9066)

working on getting rid of llvm sqrt hack
---
 test/test_speed_v_torch.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 6025c97c41..282ceed96c 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -190,6 +190,10 @@ class TestSpeed(unittest.TestCase):
     def f(a, b): return a.exp()
     helper_test_generic_square('exp', 2048, f, f, onearg=True)
 
+  def test_sqrt(self):
+    def f(a, b): return a.sqrt()
+    helper_test_generic_square('sqrt', 2048, f, f, onearg=True)
+
   def test_relu(self):
     def f(a, b): return a.relu()
     helper_test_generic_square('relu', 4096, f, f, onearg=True)

From e02e3b94c3372e21e7d9ecbc6948839b78685bfa Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 15:42:34 -0500
Subject: [PATCH 3/8] remove SQRT hack in llvm (#9067)

replaced with xpow 0.5 in transcendental. fixed sqrt(0) backward
---
 test/imported/test_indexing.py | 4 ++--
 test/test_nn.py                | 4 ++--
 test/test_ops.py               | 6 ++----
 tinygrad/codegen/rewriter.py   | 2 ++
 tinygrad/renderer/llvmir.py    | 5 -----
 5 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py
index 22a9a3f30a..21c5b251cb 100644
--- a/test/imported/test_indexing.py
+++ b/test/imported/test_indexing.py
@@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase):
     one = Tensor(1, dtype=dtypes.int64)
 
     # non-scalar indexed with scalars
-    a = Tensor.randn(2, 3)
+    a = Tensor.randn(2, 3).realize()
     numpy_testing_assert_equal_helper(a[0], a[zero])
     numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
     numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
@@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase):
     numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
 
     # scalar indexed with scalar
-    r = Tensor.randn()
+    r = Tensor.randn().realize()
     with self.assertRaises(IndexError):
       r[:]
     with self.assertRaises(IndexError):
diff --git a/test/test_nn.py b/test/test_nn.py
index 1b2963feed..995c5be444 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -563,8 +563,8 @@ class TestNN(unittest.TestCase):
     layer.weight.shard_(devices, 3)
     layer.bias.shard_(devices, None)
     state_dict = {
-      'weight': Tensor.randn(5, 3, 3, 3),
-      'bias': Tensor.randn(5),
+      'weight': Tensor.randn(5, 3, 3, 3).realize(),
+      'bias': Tensor.randn(5).realize(),
     }
     load_state_dict(layer, state_dict)
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 8bee773ae8..3aafcc0b60 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -641,9 +641,7 @@ class TestOps(unittest.TestCase):
 
   def test_sqrt(self):
     helper_test_op([(45,65)], lambda x: x.sqrt())
-    if Device.DEFAULT not in ("LLVM", "DSP"):
-      # TODO: fix backward
-      helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
+    helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
     helper_test_op([()], lambda x: x.sqrt())
   def test_rsqrt(self):
     helper_test_op([(45,65)], lambda x: x.rsqrt())
@@ -1406,7 +1404,7 @@ class TestOps(unittest.TestCase):
   def test_asinh(self):
     helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
     # NOTE: this one has larger atol
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
     helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
   def test_acosh(self):
     helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)
diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py
index 4ddaebec70..9876579f90 100644
--- a/tinygrad/codegen/rewriter.py
+++ b/tinygrad/codegen/rewriter.py
@@ -129,6 +129,8 @@ powers_of_two = {2**i:i for i in range(64)}
 def get_late_rewrite_patterns(ops, force_transcendental=False):
   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  # rewrite SQRT to xpow 0.5
+  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
   # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
   if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index d41512afe0..eebf0b8c68 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -77,8 +77,6 @@ llvm_rewrite = PatternMatcher([
     f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
 
   # unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
@@ -152,9 +150,6 @@ class LLVMRenderer(Renderer):
                  f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
     for u in uops:
-      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
-      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
-
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
         # NOTE: MallocAllocator promises 0x20 alignment
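
PATCH 3/8 makes SQRT go through the same lowering path as the other transcendentals: when a backend does not list Ops.SQRT as natively supported, the op is pattern-rewritten to xpow with exponent 0.5 instead of emitting an @llvm.sqrt intrinsic. A plain-Python sketch of the identity behind that rewrite (the exp2/log2 decomposition is an assumption about how xpow evaluates, not a transcription of tinygrad's kernel):

    import math

    def sqrt_via_pow(x: float) -> float:
        # sqrt(x) = x**0.5 = 2**(0.5 * log2(x)); 0 needs a special case since log2(0) = -inf
        if x == 0.0: return 0.0
        return 2.0 ** (0.5 * math.log2(x))

    assert math.isclose(sqrt_via_pow(2.0), math.sqrt(2.0))

With sqrt(0) well-defined everywhere, the gradient rule ctx / (ret*2) (visible as context in the gradient.py hunk of PATCH 8/8) evaluates to +inf at 0 on every backend, which is presumably why the LLVM/DSP exclusion in test_sqrt could be dropped.
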
From 9e91898941b067198edf0093a665a31413dca694 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 16:29:44 -0500
Subject: [PATCH 4/8] bert eval at the end of training (#9070)

always eval at the last epoch
---
 examples/mlperf/model_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 33b807bbfc..ad479b3c22 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -801,7 +801,7 @@ def train_bert():
               f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
 
     # ** eval loop **
-    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
+    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
       if MLLOGGER and RUNMLPERF:
         MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
       if getenv("RESET_STEP", 0): train_step_bert.reset()

From e83905696eed289c407db550833abf23f9740be3 Mon Sep 17 00:00:00 2001
From: Ahmed Harmouche
Date: Thu, 13 Feb 2025 22:30:20 +0100
Subject: [PATCH 5/8] Show install instructions when dawn library is missing (#9059)

* Show install instructions when dawn library is missing

* Handle missing dawn in ops_webgpu

* Simplify

* Solve f-string backslash error
---
 tinygrad/runtime/ops_webgpu.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py
index 357680081a..3f83f93ef0 100644
--- a/tinygrad/runtime/ops_webgpu.py
+++ b/tinygrad/runtime/ops_webgpu.py
@@ -1,12 +1,16 @@
 import functools, struct
 from tinygrad.device import Compiled, Allocator, Compiler
 from tinygrad.renderer.wgsl import WGSLRenderer
-from tinygrad.helpers import round_up
+from tinygrad.helpers import round_up, OSX
 from tinygrad.runtime.autogen import webgpu
 from typing import List, Any
 import ctypes
 
-instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
+try:
+  instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
+except AttributeError:
+  raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
+  "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
 
 def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))
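
The guard in PATCH 5/8 relies on the autogen webgpu module raising AttributeError when the underlying dawn shared library was never loaded, and converts that into an actionable install message. The same probe-and-explain pattern in a generic form (function and library names here are illustrative, not tinygrad's API):

    import ctypes

    def load_native_lib(name: str, hint: str) -> ctypes.CDLL:
        # probe for the shared library; turn the loader's error into install instructions
        try:
            return ctypes.CDLL(name)
        except OSError as e:
            raise RuntimeError(f"Cannot find {name}. Install it with: {hint}") from e

    # usage sketch: load_native_lib("libwebgpu_dawn.so", "see the pydawn releases page")
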
From 5ef48bbe0a7aa14063ede7b13188eb5676956f3a Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 16:51:21 -0500
Subject: [PATCH 6/8] swap order in rsqrt (#9069)

fixed backward for 0
---
 test/test_ops.py   | 3 +--
 tinygrad/tensor.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 3aafcc0b60..283451f1c9 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -645,8 +645,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: x.sqrt())
   def test_rsqrt(self):
     helper_test_op([(45,65)], lambda x: x.rsqrt())
-    # TODO: fix backward
-    helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]], forward_only=True)
+    helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]])
     helper_test_op([()], lambda x: x.rsqrt())
 
   def test_xor(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f180090e91..0cf6e46f5c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2652,7 +2652,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
     ```
     """
-    return self.reciprocal().sqrt()
+    return self.sqrt().reciprocal()
   def sin(self):
     """
     Computes the sine of the tensor element-wise.
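
PATCH 6/8 is subtler than it looks: reciprocal(sqrt(x)) and sqrt(reciprocal(x)) agree in the forward pass, but backward walks the composition step by step and the intermediate values at x = 0 differ. Hand-evaluating the chain rule with IEEE float semantics (an illustration of the 0-input case, not tinygrad's gradient code):

    import math
    inf = math.inf

    # old order, rsqrt = sqrt(reciprocal(x)), at x = 0:
    #   d reciprocal/dx = -1/x**2 = -inf;  d sqrt/dr at r = inf is 1/(2*sqrt(inf)) = 0
    old_grad = 0.0 * -inf   # 0 * -inf is nan: the broken backward
    # new order, rsqrt = reciprocal(sqrt(x)), at x = 0:
    #   d sqrt/dx = 1/(2*sqrt(0)) = +inf;  d reciprocal/ds at s = 0 is -1/s**2 = -inf
    new_grad = inf * -inf   # -inf, the value torch reports for rsqrt backward at 0
    assert math.isnan(old_grad) and new_grad == -inf

Swapping the order lets the test drop forward_only=True, since the gradient now matches torch at 0 instead of producing nan.
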
From 2d04a75a40f2e6a38234eac75224793108766a08 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Fri, 14 Feb 2025 01:28:10 +0200
Subject: [PATCH 7/8] start tracking bottom_up_rewrite in viz [pr] (#9071)

* start tracking bottom_up_rewrite in viz [pr]

* use the tracking matcher in test_viz
---
 test/unit/test_viz.py | 13 ++++++++-----
 tinygrad/ops.py       | 10 +++++-----
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py
index ff84813d1e..43560e22ca 100644
--- a/test/unit/test_viz.py
+++ b/test/unit/test_viz.py
@@ -1,12 +1,13 @@
 import unittest, decimal, json
 from tinygrad.dtype import dtypes
 from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic
-from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt
+from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
 from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
 from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
 
 # NOTE: VIZ tests always use the tracked PatternMatcher instance
 symbolic = TrackedPatternMatcher(symbolic.patterns)
+substitute = TrackedPatternMatcher(_substitute.patterns)
 
 class TestViz(unittest.TestCase):
   def setUp(self):
@@ -89,16 +90,18 @@ class TestViz(unittest.TestCase):
     self.assertEqual(len(uop_to_json(a*1)), 2)
     self.assertEqual(len(uop_to_json(a*b)), 3)
 
-  @unittest.skip("TODO: bring this back with better testing")
   def test_bottom_up_rewrite(self):
     a = UOp.variable("a", 0, 10)
     b = UOp.variable("b", 0, 10)
     c = UOp.variable("c", 0, 10)
-    UOp.substitute(a+b, {a+b:c})
+    @track_rewrites(named=True)
+    def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True)
+    fxn(a+b)
+    #UOp.substitute(a+b, {a+b:c})
     ret = get_metadata(keys, contexts)
     self.assertEqual(len(ret), 1)
-    _, _, vals = ret[0][0]
-    self.assertEqual(len(vals.upats), 1)
+    _, m = ret[0]
+    self.assertEqual(m[0]["match_count"], 1)
 
   # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
   def test_rewrite_without_context(self):
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 4788c6715f..e9ed4c6257 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -841,6 +841,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
 class TrackedGraphRewrite:
   loc: tuple[str, int] # location that called graph_rewrite
   sink: UOp # the sink input to graph_rewrite
+  bottom_up: bool
   matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
 tracked_keys:list[Any] = []
 tracked_ctxs:list[list[TrackedGraphRewrite]] = []
@@ -925,17 +926,16 @@ class RewriteContext:
       return ret
 
 def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
-  if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
-    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+  if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
   return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
 
 def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
-  if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
-    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+  if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+    tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
   rewrite_ctx = RewriteContext(pm, ctx)
   return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
-
 # *** most of symbolic lives here now ***
 
 def split_uop(x:UOp, sep:Ops):
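
With PATCH 7/8, bottom-up rewrites are recorded just like top-down ones (the bottom_up flag now travels inside TrackedGraphRewrite), so VIZ can replay them. The usage pattern is the one the updated test exercises, condensed here for reference:

    from tinygrad.ops import UOp, graph_rewrite, track_rewrites, TrackedPatternMatcher, _substitute

    substitute = TrackedPatternMatcher(_substitute.patterns)  # tracked matcher, as in test_viz
    a, b, c = (UOp.variable(name, 0, 10) for name in "abc")

    @track_rewrites(named=True)
    def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b: c}, bottom_up=True)

    fxn(a + b)  # with TRACK_MATCH_STATS >= 2 this records one match for the viz server
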
From 73af42aeabef13f16c776da41c59395cff0341f1 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 13 Feb 2025 21:06:01 -0500
Subject: [PATCH 8/8] fix pow backward when base is 0 (#9075)

---
 test/test_ops.py     | 14 ++++++++++++--
 tinygrad/gradient.py |  4 +++-
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 283451f1c9..fd3e692ef2 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -615,10 +615,20 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: 2.0**x)
     helper_test_op([()], lambda x: x**2.0)
     helper_test_op([()], lambda x: 2.0**x)
-    # TODO: fix backward
-    helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
+    helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
     helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])
 
+  def test_pow_zero_tensor(self):
+    helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
+    helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
+    # TODO: fix WEBGPU
+    if Device.DEFAULT != "WEBGPU":
+      helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]])
+  def test_pow_zero_const(self):
+    helper_test_op(None, lambda x: x**0.3, vals=[[0.0]])
+    helper_test_op(None, lambda x: x**0.0, vals=[[0.0]])
+    helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]])
+
   @unittest.skip("not supported")
   def test_pow_int(self):
     def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index 54e6a3465f..a925a4c4c4 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -22,7 +22,9 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
-  (UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
+  (UPat(Ops.POW, name="ret"), lambda ctx, ret:
+    (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
+     ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
  (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                 (ret.src[0]
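
For reference, the identities behind the POW gradient that PATCH 8/8 guards: with ret = x**y the two partials are d(ret)/dx = y*x**(y-1) = ret*y/x and d(ret)/dy = ret*ln(x), and both blow up at x = 0 (division by zero, log of zero). The added where() branches pick the limiting values instead; a hand-check in plain Python (illustrative of the branch logic, not tinygrad code):

    import math

    def pow_grads(x: float, y: float) -> tuple[float, float]:
        # returns (d(x**y)/dx, d(x**y)/dy) with the same 0-base conventions as the patch
        if x == 0.0:
            # base grad: 0 if y == 0 else sign(y)*inf; exponent grad: -inf if y < 0 else 0
            return (0.0 if y == 0.0 else y * math.inf), (-math.inf if y < 0 else 0.0)
        ret = x ** y
        return ret * y / x, ret * math.log(x)

    print(pow_grads(0.0, 0.3))   # (inf, 0.0)
    print(pow_grads(0.0, 0.0))   # (0.0, 0.0)
    print(pow_grads(0.0, -0.3))  # (-inf, -inf)

These three value pairs are exactly the cases the new test_pow_zero_tensor and test_pow_zero_const check against torch.
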