mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
@@ -1047,7 +1047,7 @@ def train_bert():
|
||||
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
|
||||
|
||||
# ** eval loop **
|
||||
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK):
|
||||
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
|
||||
if MLLOGGER and RUNMLPERF:
|
||||
MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
|
||||
if getenv("RESET_STEP", 0): train_step_bert.reset()
|
||||
|
||||
@@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase):
|
||||
one = Tensor(1, dtype=dtypes.int64)
|
||||
|
||||
# non-scalar indexed with scalars
|
||||
a = Tensor.randn(2, 3)
|
||||
a = Tensor.randn(2, 3).realize()
|
||||
numpy_testing_assert_equal_helper(a[0], a[zero])
|
||||
numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
|
||||
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
|
||||
@@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase):
|
||||
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
|
||||
|
||||
# scalar indexed with scalar
|
||||
r = Tensor.randn()
|
||||
r = Tensor.randn().realize()
|
||||
with self.assertRaises(IndexError):
|
||||
r[:]
|
||||
with self.assertRaises(IndexError):
|
||||
|
||||
@@ -66,20 +66,17 @@ class TestArange(unittest.TestCase):
|
||||
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
# update: passing after CAST_BEFORE_VIEW=1 deletion
|
||||
# @unittest.expectedFailure
|
||||
def test_arange_2_reduce(self):
|
||||
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
|
||||
needle[1337] = 1
|
||||
needle.realize()
|
||||
with Context(NOOPT=1, FUSE_ARANGE=1):
|
||||
GlobalCounters.reset()
|
||||
# TODO: it should work without these reshapes
|
||||
out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
|
||||
out = ((Tensor.arange(1,16385)-1)*needle).sum()
|
||||
sched = out.schedule()
|
||||
assert len(sched) == 1
|
||||
self.assertEqual(len(sched), 1)
|
||||
run_schedule(sched)
|
||||
assert out.item() == 1337, f"expected 1337, got {out.item()}"
|
||||
self.assertEqual(out.item(), 1337)
|
||||
|
||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||
def test_manual_index(self):
|
||||
@@ -95,7 +92,7 @@ class TestIndexing(unittest.TestCase):
|
||||
full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
|
||||
X = full.sum(axis=(2,3))
|
||||
sched = X.schedule()
|
||||
assert len(sched) == 1
|
||||
self.assertEqual(len(sched), 1)
|
||||
run_schedule(sched)
|
||||
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
@@ -111,7 +108,7 @@ class TestIndexing(unittest.TestCase):
|
||||
assert X.shape == (4,256)
|
||||
sched = X.schedule()
|
||||
# TODO: enable these asserts when the scheduler can handle this
|
||||
#assert len(sched) == 1, f"{len(sched)} != 1"
|
||||
#self.assertEqual(len(sched), 1)
|
||||
run_schedule(sched)
|
||||
#assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
@@ -126,7 +123,7 @@ class TestIndexing(unittest.TestCase):
|
||||
X = dataset[idxs]
|
||||
assert X.shape == (4,256)
|
||||
sched = X.schedule()
|
||||
assert len(sched) == 2
|
||||
self.assertEqual(len(sched), 2)
|
||||
run_schedule(sched)
|
||||
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
|
||||
np.testing.assert_allclose(real_index, X.numpy())
|
||||
|
||||
@@ -563,8 +563,8 @@ class TestNN(unittest.TestCase):
|
||||
layer.weight.shard_(devices, 3)
|
||||
layer.bias.shard_(devices, None)
|
||||
state_dict = {
|
||||
'weight': Tensor.randn(5, 3, 3, 3),
|
||||
'bias': Tensor.randn(5),
|
||||
'weight': Tensor.randn(5, 3, 3, 3).realize(),
|
||||
'bias': Tensor.randn(5).realize(),
|
||||
}
|
||||
load_state_dict(layer, state_dict)
|
||||
|
||||
|
||||
@@ -615,10 +615,20 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: 2.0**x)
|
||||
helper_test_op([()], lambda x: x**2.0)
|
||||
helper_test_op([()], lambda x: 2.0**x)
|
||||
# TODO: fix backward
|
||||
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
|
||||
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
|
||||
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])
|
||||
|
||||
def test_pow_zero_tensor(self):
|
||||
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
|
||||
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
|
||||
# TODO: fix WEBGPU
|
||||
if Device.DEFAULT != "WEBGPU":
|
||||
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]])
|
||||
def test_pow_zero_const(self):
|
||||
helper_test_op(None, lambda x: x**0.3, vals=[[0.0]])
|
||||
helper_test_op(None, lambda x: x**0.0, vals=[[0.0]])
|
||||
helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]])
|
||||
|
||||
@unittest.skip("not supported")
|
||||
def test_pow_int(self):
|
||||
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
|
||||
@@ -641,14 +651,11 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_sqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sqrt())
|
||||
if Device.DEFAULT not in ("LLVM", "DSP"):
|
||||
# TODO: fix backward
|
||||
helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
|
||||
helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
|
||||
helper_test_op([()], lambda x: x.sqrt())
|
||||
def test_rsqrt(self):
|
||||
helper_test_op([(45,65)], lambda x: x.rsqrt())
|
||||
# TODO: fix backward
|
||||
helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]])
|
||||
helper_test_op([()], lambda x: x.rsqrt())
|
||||
|
||||
def test_xor(self):
|
||||
@@ -1406,7 +1413,7 @@ class TestOps(unittest.TestCase):
|
||||
def test_asinh(self):
|
||||
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
|
||||
# NOTE: this one has larger atol
|
||||
helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
|
||||
helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
|
||||
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
|
||||
def test_acosh(self):
|
||||
helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)
|
||||
|
||||
@@ -190,6 +190,10 @@ class TestSpeed(unittest.TestCase):
|
||||
def f(a, b): return a.exp()
|
||||
helper_test_generic_square('exp', 2048, f, f, onearg=True)
|
||||
|
||||
def test_sqrt(self):
|
||||
def f(a, b): return a.sqrt()
|
||||
helper_test_generic_square('sqrt', 2048, f, f, onearg=True)
|
||||
|
||||
def test_relu(self):
|
||||
def f(a, b): return a.relu()
|
||||
helper_test_generic_square('relu', 4096, f, f, onearg=True)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt
|
||||
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
|
||||
|
||||
# NOTE: VIZ tests always use the tracked PatternMatcher instance
|
||||
symbolic = TrackedPatternMatcher(symbolic.patterns)
|
||||
substitute = TrackedPatternMatcher(_substitute.patterns)
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -89,16 +90,18 @@ class TestViz(unittest.TestCase):
|
||||
self.assertEqual(len(uop_to_json(a*1)), 2)
|
||||
self.assertEqual(len(uop_to_json(a*b)), 3)
|
||||
|
||||
@unittest.skip("TODO: bring this back with better testing")
|
||||
def test_bottom_up_rewrite(self):
|
||||
a = UOp.variable("a", 0, 10)
|
||||
b = UOp.variable("b", 0, 10)
|
||||
c = UOp.variable("c", 0, 10)
|
||||
UOp.substitute(a+b, {a+b:c})
|
||||
@track_rewrites(named=True)
|
||||
def fxn(sink): return graph_rewrite(sink, substitute, ctx={a+b:c}, bottom_up=True)
|
||||
fxn(a+b)
|
||||
#UOp.substitute(a+b, {a+b:c})
|
||||
ret = get_metadata(keys, contexts)
|
||||
self.assertEqual(len(ret), 1)
|
||||
_, _, vals = ret[0][0]
|
||||
self.assertEqual(len(vals.upats), 1)
|
||||
_, m = ret[0]
|
||||
self.assertEqual(m[0]["match_count"], 1)
|
||||
|
||||
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
|
||||
def test_rewrite_without_context(self):
|
||||
|
||||
@@ -129,6 +129,8 @@ powers_of_two = {2**i:i for i in range(64)}
|
||||
def get_late_rewrite_patterns(ops, force_transcendental=False):
|
||||
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
|
||||
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
|
||||
# rewrite SQRT to xpow 0.5
|
||||
if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
|
||||
if Ops.AND in ops:
|
||||
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
|
||||
|
||||
@@ -22,7 +22,9 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
|
||||
(UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
|
||||
(UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
|
||||
(UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
|
||||
(UPat(Ops.POW, name="ret"), lambda ctx, ret:
|
||||
(ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
|
||||
ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
|
||||
(UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
|
||||
(ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
|
||||
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
|
||||
|
||||
@@ -841,6 +841,7 @@ match_stats:dict[UPat, list[Union[int, float]]] = dict()
|
||||
class TrackedGraphRewrite:
|
||||
loc: tuple[str, int] # location that called graph_rewrite
|
||||
sink: UOp # the sink input to graph_rewrite
|
||||
bottom_up: bool
|
||||
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
|
||||
tracked_keys:list[Any] = []
|
||||
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
|
||||
@@ -925,17 +926,16 @@ class RewriteContext:
|
||||
return ret
|
||||
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
|
||||
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
|
||||
|
||||
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False) -> dict[UOp, UOp]:
|
||||
if TRACK_MATCH_STATS >= 2 and not bottom_up and len(tracked_ctxs) != 0: # TODO: make viz work with bottom_up=True
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink))
|
||||
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
|
||||
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up))
|
||||
rewrite_ctx = RewriteContext(pm, ctx)
|
||||
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
|
||||
|
||||
|
||||
# *** most of symbolic lives here now ***
|
||||
|
||||
def split_uop(x:UOp, sep:Ops):
|
||||
|
||||
@@ -77,8 +77,6 @@ llvm_rewrite = PatternMatcher([
|
||||
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
|
||||
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.SQRT, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
|
||||
@@ -152,9 +150,6 @@ class LLVMRenderer(Renderer):
|
||||
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
|
||||
|
||||
for u in uops:
|
||||
# hack for defining sqrt function (TODO: can we get a transcendental for this?)
|
||||
if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
|
||||
|
||||
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
|
||||
# NOTE: MallocAllocator promises 0x20 alignment
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import functools, struct
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.renderer.wgsl import WGSLRenderer
|
||||
from tinygrad.helpers import round_up
|
||||
from tinygrad.helpers import round_up, OSX
|
||||
from tinygrad.runtime.autogen import webgpu
|
||||
from typing import List, Any
|
||||
import ctypes
|
||||
|
||||
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
|
||||
try:
|
||||
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
|
||||
except AttributeError:
|
||||
raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
|
||||
"sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
|
||||
|
||||
def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))
|
||||
|
||||
|
||||
@@ -2652,7 +2652,7 @@ class Tensor(SimpleMathTrait):
|
||||
print(Tensor([1., 2., 3., 4.]).rsqrt().numpy())
|
||||
```
|
||||
"""
|
||||
return self.reciprocal().sqrt()
|
||||
return self.sqrt().reciprocal()
|
||||
def sin(self):
|
||||
"""
|
||||
Computes the sine of the tensor element-wise.
|
||||
|
||||
Reference in New Issue
Block a user