mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -277,7 +277,7 @@ if __name__ == "__main__":
|
||||
elif args.beamreplay is not None:
|
||||
print(f"loading BEAM replay from file '{args.beamreplay}'")
|
||||
with open(args.beamreplay, 'r') as file: fdata = file.readlines()
|
||||
ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata]
|
||||
ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata if not x.startswith("#")], [x.split(' :: ')[1] for x in fdata if not x.startswith("#")]
|
||||
|
||||
# dedup ast_strs and opts_list
|
||||
dct = defaultdict(list)
|
||||
|
||||
@@ -1317,5 +1317,21 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI")
|
||||
def test_failure_55(self):
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.ops import UOp, UOps, simplify_valid
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(UOps.LOAD, dtypes.float, (
|
||||
@@ -64,7 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
"0",
|
||||
"(((lidx0+(gidx0*4))<19)!=True)")
|
||||
|
||||
def test_simplify_within_valid(self):
|
||||
def test_simplify_within_valid1(self):
|
||||
ridx0 = Range(0, 4)
|
||||
ridx1 = Range(1, 4)
|
||||
ridx2 = Range(2, 4)
|
||||
@@ -76,6 +76,13 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
"(((ridx0+ridx1)+ridx2)+ridx3)",
|
||||
"((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")
|
||||
|
||||
def test_simplify_within_valid2(self):
|
||||
gidx0 = Special("gidx0", 56)
|
||||
ridx0 = Range(0, 3)
|
||||
alu0 = gidx0+ridx0
|
||||
valid = alu0.lt(57) & alu0.ge(1)
|
||||
self.assertIsNone(simplify_valid(valid))
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
|
||||
@@ -156,7 +156,10 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
|
||||
with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
except RuntimeError: continue # for runtime issues
|
||||
except RuntimeError as e:
|
||||
if len(CAPTURE_BEAM) > 0:
|
||||
with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n")
|
||||
continue # for runtime issues
|
||||
timed_lins.append((acted_lins[i], min(tms)))
|
||||
if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501
|
||||
elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
|
||||
|
||||
@@ -944,7 +944,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
|
||||
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
||||
candidates = []
|
||||
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)):
|
||||
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
|
||||
# try checking the whole clause
|
||||
|
||||
Reference in New Issue
Block a user