Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-02-14 06:28:37 +00:00
13 changed files with 55 additions and 41 deletions

View File

@@ -1047,7 +1047,7 @@ def train_bert():
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
# ** eval loop **
-if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
+if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
if MLLOGGER and RUNMLPERF:
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
if getenv("RESET_STEP", 0): train_step_bert.reset()

View File

@@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase):
one = Tensor(1, dtype=dtypes.int64)
# non-scalar indexed with scalars
-a = Tensor.randn(2, 3)
+a = Tensor.randn(2, 3).realize()
numpy_testing_assert_equal_helper(a[0], a[zero])
numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
@@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
# scalar indexed with scalar
-r = Tensor.randn()
+r = Tensor.randn().realize()
with self.assertRaises(IndexError):
r[:]
with self.assertRaises(IndexError):
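Note on the `.realize()` additions: Tensor.randn is lazy, and realizing first materializes the random data into one buffer, presumably so every indexing path in these tests (and the numpy reference) reads identical values. A minimal sketch:

from tinygrad import Tensor
a = Tensor.randn(2, 3)  # lazy: just a graph, no random data generated yet
a = a.realize()         # materialized: a[0], a[zero], and a.numpy() all see the same fixed buffer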

View File

@@ -66,20 +66,17 @@ class TestArange(unittest.TestCase):
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
class TestIndexing(unittest.TestCase):
+# update: passing after CAST_BEFORE_VIEW=1 deletion
# @unittest.expectedFailure
def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1
needle.realize()
with Context(NOOPT=1, FUSE_ARANGE=1):
GlobalCounters.reset()
-# TODO: it should work without these reshapes
-out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
+out = ((Tensor.arange(1,16385)-1)*needle).sum()
sched = out.schedule()
-assert len(sched) == 1
+self.assertEqual(len(sched), 1)
run_schedule(sched)
-assert out.item() == 1337, f"expected 1337, got {out.item()}"
+self.assertEqual(out.item(), 1337)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_manual_index(self):
@@ -95,7 +92,7 @@ class TestIndexing(unittest.TestCase):
full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
X = full.sum(axis=(2,3))
sched = X.schedule()
-assert len(sched) == 1
+self.assertEqual(len(sched), 1)
run_schedule(sched)
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
np.testing.assert_allclose(real_index, X.numpy())
@@ -111,7 +108,7 @@ class TestIndexing(unittest.TestCase):
assert X.shape == (4,256)
sched = X.schedule()
# TODO: enable these asserts when the scheduler can handle this
-#assert len(sched) == 1, f"{len(sched)} != 1"
+#self.assertEqual(len(sched), 1)
run_schedule(sched)
#assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
np.testing.assert_allclose(real_index, X.numpy())
@@ -126,7 +123,7 @@ class TestIndexing(unittest.TestCase):
X = dataset[idxs]
assert X.shape == (4,256)
sched = X.schedule()
-assert len(sched) == 2
+self.assertEqual(len(sched), 2)
run_schedule(sched)
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
np.testing.assert_allclose(real_index, X.numpy())
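These tests exercise FUSE_ARANGE-style indexing, where a gather is written as a reduction: a one-hot "needle" multiplied by an arange collapses to the needle's position. A numpy sketch of the identity test_arange_2_reduce now checks in one fused kernel:

import numpy as np
n = 16384
needle = np.zeros(n, dtype=np.int32)
needle[1337] = 1
out = ((np.arange(1, n + 1) - 1) * needle).sum()  # the test's expression, minus the reshapes
assert out == 1337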

View File

@@ -563,8 +563,8 @@ class TestNN(unittest.TestCase):
layer.weight.shard_(devices, 3)
layer.bias.shard_(devices, None)
state_dict = {
-'weight': Tensor.randn(5, 3, 3, 3),
-'bias': Tensor.randn(5),
+'weight': Tensor.randn(5, 3, 3, 3).realize(),
+'bias': Tensor.randn(5).realize(),
}
load_state_dict(layer, state_dict)

View File

@@ -615,10 +615,20 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: 2.0**x)
helper_test_op([()], lambda x: x**2.0)
helper_test_op([()], lambda x: 2.0**x)
-# TODO: fix backward
-helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
+helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
+helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])
+def test_pow_zero_tensor(self):
+helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
+helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
+# TODO: fix WEBGPU
+if Device.DEFAULT != "WEBGPU":
+helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]])
+def test_pow_zero_const(self):
+helper_test_op(None, lambda x: x**0.3, vals=[[0.0]])
+helper_test_op(None, lambda x: x**0.0, vals=[[0.0]])
+helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]])
@unittest.skip("not supported")
def test_pow_int(self):
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
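For reference, the zero-base cases these new tests pin down follow IEEE pow semantics (numpy shown, since plain Python raises on 0.0 ** -0.3):

import numpy as np
x = np.float64(0.0)
print(x ** 0.3)  # 0.0: zero to a positive power
print(x ** 0.0)  # 1.0: pow(0, 0) is defined as 1
with np.errstate(all="ignore"):
    print(x ** -0.3)  # inf: the case still skipped on WEBGPU above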
@@ -641,14 +651,11 @@ class TestOps(unittest.TestCase):
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt())
-if Device.DEFAULT not in ("LLVM", "DSP"):
-# TODO: fix backward
-helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
+helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
helper_test_op([()], lambda x: x.sqrt())
def test_rsqrt(self):
helper_test_op([(45,65)], lambda x: x.rsqrt())
-# TODO: fix backward
-helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]], forward_only=True)
+helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]])
helper_test_op([()], lambda x: x.rsqrt())
def test_xor(self):
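The re-enabled backward tests depend on the derivatives diverging cleanly at zero instead of producing NaN: d/dx sqrt(x) = 1/(2*sqrt(x)) -> +inf and d/dx rsqrt(x) = -1/(2*x**1.5) -> -inf as x -> 0+. A quick numeric look just above zero:

import numpy as np
x = np.float64(1e-12)
print(1 / (2 * np.sqrt(x)))  # 5e5, growing toward +inf at 0
print(-0.5 * x ** -1.5)      # -5e17, growing toward -inf at 0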
@@ -1406,7 +1413,7 @@ class TestOps(unittest.TestCase):
def test_asinh(self):
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
# NOTE: this one has larger atol
-helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
def test_acosh(self):
helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)

View File

@@ -190,6 +190,10 @@ class TestSpeed(unittest.TestCase):
def f(a, b): return a.exp()
helper_test_generic_square('exp', 2048, f, f, onearg=True)
+def test_sqrt(self):
+def f(a, b): return a.sqrt()
+helper_test_generic_square('sqrt', 2048, f, f, onearg=True)
def test_relu(self):
def f(a, b): return a.relu()
helper_test_generic_square('relu', 4096, f, f, onearg=True)

View File

@@ -1,12 +1,13 @@
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic
-from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt
+from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
# NOTE: VIZ tests always use the tracked PatternMatcher instance
symbolic = TrackedPatternMatcher(symbolic.patterns)
+substitute = TrackedPatternMatcher(_substitute.patterns)
class TestViz(unittest.TestCase):
def setUp(self):
@@ -89,16 +90,18 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(uop_to_json(a*1)), 2)
self.assertEqual(len(uop_to_json(a*b)), 3)
+@unittest.skip("TODO: bring this back with better testing")
def test_bottom_up_rewrite(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
-UOp.substitute(a+b, {a+b:c})
+@track_rewrites(named=True)
+def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True)
+fxn(a+b)
+#UOp.substitute(a+b, {a+b:c})
ret = get_metadata(keys, contexts)
self.assertEqual(len(ret), 1)
-_, _, vals = ret[0][0]
-self.assertEqual(len(vals.upats), 1)
+_, m = ret[0]
+self.assertEqual(m[0]["match_count"], 1)
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
def test_rewrite_without_context(self):
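The rewritten (and for now skipped) test shows the tracking contract: graph_rewrite only records into the viz contexts when the caller is decorated with track_rewrites. A condensed sketch using only names from this diff:

from tinygrad.ops import TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, _substitute
substitute = TrackedPatternMatcher(_substitute.patterns)
a, b, c = (UOp.variable(n, 0, 10) for n in "abc")

@track_rewrites(named=True)
def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b: c}, bottom_up=True)
fxn(a+b)  # rewrites a+b -> c, and the match is recorded for VIZ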

View File

@@ -129,6 +129,8 @@ powers_of_two = {2**i:i for i in range(64)}
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+# rewrite SQRT to xpow 0.5
+if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]

View File

@@ -22,7 +22,9 @@ pm_gradient = PatternMatcher([
(UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
(UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
(UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
(UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
(UPat(Ops.POW, name="ret"), lambda ctx, ret:
(ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
(UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
(ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),

View File

@@ -841,6 +841,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
class TrackedGraphRewrite:
loc: tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink input to graph_rewrite
+bottom_up: bool
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
@@ -925,17 +926,16 @@ class RewriteContext:
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
-if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
-tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
-if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
-tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
+if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
+tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
rewrite_ctx = RewriteContext(pm, ctx)
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
# *** most of symbolic lives here now ***
def split_uop(x:UOp, sep:Ops):

View File

@@ -77,8 +77,6 @@ llvm_rewrite = PatternMatcher([
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
# unary/binary/ternary ops
(UPat(Ops.SQRT, name="x"), lambda ctx,x:
f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
@@ -152,9 +150,6 @@ class LLVMRenderer(Renderer):
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
for u in uops:
-# hack for defining sqrt function (TODO: can we get a transcendental for this?)
-if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
# NOTE: MallocAllocator promises 0x20 alignment

View File

@@ -1,12 +1,16 @@
import functools, struct
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.renderer.wgsl import WGSLRenderer
-from tinygrad.helpers import round_up
+from tinygrad.helpers import round_up, OSX
from tinygrad.runtime.autogen import webgpu
from typing import List, Any
import ctypes
-instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
+try:
+  instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
+except AttributeError:
+  raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
+    "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))

View File

@@ -2652,7 +2652,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
```
"""
-return self.reciprocal().sqrt()
+return self.sqrt().reciprocal()
def sin(self):
"""
Computes the sine of the tensor element-wise.
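Note on the rsqrt change: both orderings agree in exact arithmetic (1/sqrt(x) == sqrt(1/x) for x > 0), but they build different autodiff graphs; presumably sqrt-then-reciprocal composes correctly with the new SQRT gradient above. A forward sanity check:

import numpy as np
x = np.array([1., 2., 3., 4.])
np.testing.assert_allclose(1 / np.sqrt(x), np.sqrt(1 / x))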