diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 27375804c3..18c76fd5c2 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1047,7 +1047,7 @@ def train_bert(): f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}") # ** eval loop ** - if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK): + if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps: if MLLOGGER and RUNMLPERF: MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i}) if getenv("RESET_STEP", 0): train_step_bert.reset() diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py index 22a9a3f30a..21c5b251cb 100644 --- a/test/imported/test_indexing.py +++ b/test/imported/test_indexing.py @@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase): one = Tensor(1, dtype=dtypes.int64) # non-scalar indexed with scalars - a = Tensor.randn(2, 3) + a = Tensor.randn(2, 3).realize() numpy_testing_assert_equal_helper(a[0], a[zero]) numpy_testing_assert_equal_helper(a[0][1], a[zero][one]) numpy_testing_assert_equal_helper(a[0, 1], a[zero, one]) @@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase): numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)]) # scalar indexed with scalar - r = Tensor.randn() + r = Tensor.randn().realize() with self.assertRaises(IndexError): r[:] with self.assertRaises(IndexError): diff --git a/test/test_arange.py b/test/test_arange.py index 4d0bb79dc5..a9bdc3b89b 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -66,20 +66,17 @@ class TestArange(unittest.TestCase): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) class TestIndexing(unittest.TestCase): - # update: passing after CAST_BEFORE_VIEW=1 deletion - # @unittest.expectedFailure def test_arange_2_reduce(self): needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous() needle[1337] = 1 needle.realize() with Context(NOOPT=1, FUSE_ARANGE=1): GlobalCounters.reset() - # TODO: it should work without these reshapes - out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum() + out = ((Tensor.arange(1,16385)-1)*needle).sum() sched = out.schedule() - assert len(sched) == 1 + self.assertEqual(len(sched), 1) run_schedule(sched) - assert out.item() == 1337, f"expected 1337, got {out.item()}" + self.assertEqual(out.item(), 1337) @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") def test_manual_index(self): @@ -95,7 +92,7 @@ class TestIndexing(unittest.TestCase): full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1)) X = full.sum(axis=(2,3)) sched = X.schedule() - assert len(sched) == 1 + self.assertEqual(len(sched), 1) run_schedule(sched) assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}" np.testing.assert_allclose(real_index, X.numpy()) @@ -111,7 +108,7 @@ class TestIndexing(unittest.TestCase): assert X.shape == (4,256) sched = X.schedule() # TODO: enable these asserts when the scheduler can handle this - #assert len(sched) == 1, f"{len(sched)} != 1" + #self.assertEqual(len(sched), 1) run_schedule(sched) #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}" np.testing.assert_allclose(real_index, X.numpy()) @@ -126,7 +123,7 @@ class TestIndexing(unittest.TestCase): X = dataset[idxs] assert X.shape == (4,256) sched = X.schedule() - assert len(sched) == 2 + self.assertEqual(len(sched), 2) run_schedule(sched) assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}" np.testing.assert_allclose(real_index, X.numpy()) diff --git a/test/test_nn.py b/test/test_nn.py index 1b2963feed..995c5be444 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -563,8 +563,8 @@ class TestNN(unittest.TestCase): layer.weight.shard_(devices, 3) layer.bias.shard_(devices, None) state_dict = { - 'weight': Tensor.randn(5, 3, 3, 3), - 'bias': Tensor.randn(5), + 'weight': Tensor.randn(5, 3, 3, 3).realize(), + 'bias': Tensor.randn(5).realize(), } load_state_dict(layer, state_dict) diff --git a/test/test_ops.py b/test/test_ops.py index 8bee773ae8..fd3e692ef2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -615,10 +615,20 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: 2.0**x) helper_test_op([()], lambda x: x**2.0) helper_test_op([()], lambda x: 2.0**x) - # TODO: fix backward - helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True) + helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]]) helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]]) + def test_pow_zero_tensor(self): + helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]]) + helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]]) + # TODO: fix WEBGPU + if Device.DEFAULT != "WEBGPU": + helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]]) + def test_pow_zero_const(self): + helper_test_op(None, lambda x: x**0.3, vals=[[0.0]]) + helper_test_op(None, lambda x: x**0.0, vals=[[0.0]]) + helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]]) + @unittest.skip("not supported") def test_pow_int(self): def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True) @@ -641,14 +651,11 @@ class TestOps(unittest.TestCase): def test_sqrt(self): helper_test_op([(45,65)], lambda x: x.sqrt()) - if Device.DEFAULT not in ("LLVM", "DSP"): - # TODO: fix backward - helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]]) + helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]]) helper_test_op([()], lambda x: x.sqrt()) def test_rsqrt(self): helper_test_op([(45,65)], lambda x: x.rsqrt()) - # TODO: fix backward - helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]], forward_only=True) + helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]]) helper_test_op([()], lambda x: x.rsqrt()) def test_xor(self): @@ -1406,7 +1413,7 @@ class TestOps(unittest.TestCase): def test_asinh(self): helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6) # NOTE: this one has larger atol - helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297) + helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297) helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303) def test_acosh(self): helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index 6025c97c41..282ceed96c 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -190,6 +190,10 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a.exp() helper_test_generic_square('exp', 2048, f, f, onearg=True) + def test_sqrt(self): + def f(a, b): return a.sqrt() + helper_test_generic_square('sqrt', 2048, f, f, onearg=True) + def test_relu(self): def f(a, b): return a.relu() helper_test_generic_square('relu', 4096, f, f, onearg=True) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index ff84813d1e..43560e22ca 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -1,12 +1,13 @@ import unittest, decimal, json from tinygrad.dtype import dtypes from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic -from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt +from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto # NOTE: VIZ tests always use the tracked PatternMatcher instance symbolic = TrackedPatternMatcher(symbolic.patterns) +substitute = TrackedPatternMatcher(_substitute.patterns) class TestViz(unittest.TestCase): def setUp(self): @@ -89,16 +90,18 @@ class TestViz(unittest.TestCase): self.assertEqual(len(uop_to_json(a*1)), 2) self.assertEqual(len(uop_to_json(a*b)), 3) - @unittest.skip("TODO: bring this back with better testing") def test_bottom_up_rewrite(self): a = UOp.variable("a", 0, 10) b = UOp.variable("b", 0, 10) c = UOp.variable("c", 0, 10) - UOp.substitute(a+b, {a+b:c}) + @track_rewrites(named=True) + def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True) + fxn(a+b) + #UOp.substitute(a+b, {a+b:c}) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - _, _, vals = ret[0][0] - self.assertEqual(len(vals.upats), 1) + _, m = ret[0] + self.assertEqual(m[0]["match_count"], 1) # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ def test_rewrite_without_context(self): diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 4ddaebec70..9876579f90 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -129,6 +129,8 @@ powers_of_two = {2**i:i for i in range(64)} def get_late_rewrite_patterns(ops, force_transcendental=False): pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] + # rewrite SQRT to xpow 0.5 + if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5)))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 54e6a3465f..a925a4c4c4 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -22,7 +22,9 @@ pm_gradient = PatternMatcher([ (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)), (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)), (UPat(Ops.ADD), lambda ctx: (ctx, ctx)), - (UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))), + (UPat(Ops.POW, name="ret"), lambda ctx, ret: + (ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]), + ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))), (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)), (ret.src[0] UOp: - if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True - tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) + if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0: + tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up)) return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink) def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]: - if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True - tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink)) + if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0: + tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up)) rewrite_ctx = RewriteContext(pm, ctx) return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]} - # *** most of symbolic lives here now *** def split_uop(x:UOp, sep:Ops): diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index d41512afe0..eebf0b8c68 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -77,8 +77,6 @@ llvm_rewrite = PatternMatcher([ f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None), # unary/binary/ternary ops - (UPat(Ops.SQRT, name="x"), lambda ctx,x: - f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), (UPat(GroupOp.Binary, name="x"), lambda ctx,x: @@ -152,9 +150,6 @@ class LLVMRenderer(Renderer): f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"] for u in uops: - # hack for defining sqrt function (TODO: can we get a transcendental for this?) - if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None - if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}" # NOTE: MallocAllocator promises 0x20 alignment diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 357680081a..3f83f93ef0 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -1,12 +1,16 @@ import functools, struct from tinygrad.device import Compiled, Allocator, Compiler from tinygrad.renderer.wgsl import WGSLRenderer -from tinygrad.helpers import round_up +from tinygrad.helpers import round_up, OSX from tinygrad.runtime.autogen import webgpu from typing import List, Any import ctypes -instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True))) +try: + instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True))) +except AttributeError: + raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else + "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so")) def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8')) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f180090e91..0cf6e46f5c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2652,7 +2652,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 3., 4.]).rsqrt().numpy()) ``` """ - return self.reciprocal().sqrt() + return self.sqrt().reciprocal() def sin(self): """ Computes the sine of the tensor element-wise.