mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove custom render in test_simplify_valid_idx (#7303)
use UOp render to compare
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||
from tinygrad.ops import UOp, UOps
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(UOps.LOAD, dtypes.float, (
|
||||
@@ -22,19 +21,9 @@ def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UO
|
||||
valid
|
||||
))
|
||||
|
||||
def render(uop:UOp) -> str:
|
||||
uops = linearize_uop(full_graph_rewrite(uop.sink()))
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
class TestRenderer(OpenCLRenderer):
|
||||
code_for_op = {**OpenCLRenderer.code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
fxn = TestRenderer().render("", uops)
|
||||
# print(fxn)
|
||||
return fxn.split("val0 = ")[1].split(";")[0]
|
||||
|
||||
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||
def Range(n, nmax):
|
||||
return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
def Range(n, nmax): return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_is_increasing(self):
|
||||
@@ -59,64 +48,12 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertTrue(is_increasing(rng+2))
|
||||
|
||||
class TestValidIdxSimplification(unittest.TestCase):
|
||||
def check(self, val0, sidx, svalid):
|
||||
val0 = full_graph_rewrite(val0.sink()).src[0]
|
||||
idx, valid = val0.src[1], val0.src[3]
|
||||
def check(self, load, sidx, svalid):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
idx, valid = load.src[1], load.src[3]
|
||||
self.assertEqual(idx.render(simplify=False), sidx)
|
||||
self.assertEqual(valid.render(simplify=False), svalid)
|
||||
|
||||
@unittest.skip("need a different way to test conv2d backward")
|
||||
def test_conv_backward(self):
|
||||
# DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d
|
||||
gidx0 = Special("gidx0", 3)
|
||||
gidx1 = Special("gidx1", 3)
|
||||
lidx0 = Special("lidx0", 4)
|
||||
lidx1 = Special("lidx1", 3)
|
||||
lidx2 = Special("lidx2", 3)
|
||||
ridx0 = Range(0, 4)
|
||||
alu0 = gidx0*3
|
||||
alu1 = (alu0+lidx2)
|
||||
alu2 = (gidx1*3)
|
||||
alu3 = (alu1+7)
|
||||
alu4 = (alu1+8)
|
||||
alu5 = (alu1+9)
|
||||
alu6 = ((gidx0+9)//10)
|
||||
alu7 = (alu3%10)
|
||||
alu8 = (alu4%10)
|
||||
alu9 = (alu5%10)
|
||||
alu10 = (gidx1+(ridx0*3))
|
||||
alu11 = (ridx0*9)
|
||||
alu12 = (alu2+lidx1+alu11)
|
||||
alu13 = ((alu6+alu2+lidx1+alu11)%10)
|
||||
alu14 = (alu12%10)
|
||||
alu15 = (((((alu10//10)+lidx0)%4)*441)+(((alu12//10)%3)*3)+(alu14*63))
|
||||
alu16 = alu12.lt(30)
|
||||
alu17 = alu16&(alu14.lt(7))
|
||||
|
||||
# TODO: simplify these
|
||||
val0 = get_gated_load_uop(alu17&(alu9.lt(7)), alu15+(alu5//10)+(alu9*9))
|
||||
self.check(val0,
|
||||
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+((((gidx0*3)+lidx2)+9)//10))+(((((gidx0*3)+lidx2)+9)%10)*9))",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+9)%10)<7))")
|
||||
|
||||
val1 = get_gated_load_uop(
|
||||
((alu16&gidx0.lt(1))&alu13.lt(7))&alu7.lt(7),
|
||||
((((((((((lidx1*10)+gidx0)//3)+3)//10)+alu10)//10)+lidx0)%4)*441)+((((alu6+alu12)//10)%3)*3)+(alu13*63)+(((alu3//10)+2)%3)+(alu7*9)
|
||||
)
|
||||
self.check(val1,
|
||||
"(((lidx2*9)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+65)",
|
||||
"(((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(gidx0<1))&(((((((gidx0+9)//10)+(gidx1*3))+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+7)%10)<7))")
|
||||
|
||||
val2 = get_gated_load_uop(alu17&alu1.lt(7), alu15+(gidx0*27)+(lidx2*9))
|
||||
self.check(val2,
|
||||
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+(gidx0*27))+(lidx2*9))",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((gidx0*3)+lidx2)<7))")
|
||||
|
||||
val3 = get_gated_load_uop(alu17&alu8.lt(7), (alu4//10)+alu15+(alu8*9)+1)
|
||||
self.check(val3,
|
||||
"(((((((gidx0*3)+lidx2)+8)//10)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+(((((gidx0*3)+lidx2)+8)%10)*9))+1)",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+8)%10)<7))")
|
||||
|
||||
def test_cumsum(self):
|
||||
gidx0 = Special("gidx0", 5)
|
||||
lidx0 = Special("lidx0", 4)
|
||||
@@ -140,6 +77,16 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
"((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
idx = load.src[1]
|
||||
self.assertEqual(idx.op, UOps.VECTORIZE)
|
||||
self.assertEqual(len(idx.src), 2)
|
||||
idx0, idx1 = idx.src[0], idx.src[1]
|
||||
self.assertEqual(idx0.render(simplify=False), sidx0)
|
||||
self.assertEqual(idx1.render(simplify=False), sidx1)
|
||||
if svalid is not None: self.assertEqual(load.src[3].render(simplify=False), svalid)
|
||||
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
# (idx1 < c+1).ne(True) -> idx1 > c
|
||||
@@ -147,36 +94,33 @@ class TestImageSimplification(unittest.TestCase):
|
||||
gidx1 = Special("gidx1", 32)
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-1))
|
||||
self.assertEqual(render(load), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+-1)))")
|
||||
self.check(load, None, "gidx0", "(gidx1+-1)")
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), (gidx0, gidx1-2))
|
||||
self.assertEqual(render(load), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+-2)))")
|
||||
self.check(load, None, "gidx0", "(gidx1+-2)")
|
||||
|
||||
# should match any one of the AND clauses and drop the matched clause from valid
|
||||
valid = (gidx0).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
load = get_load_image_uop(shape, valid, (gidx0+1, gidx1-1))
|
||||
self.assertEqual(render(load),
|
||||
"(((gidx0<1)!=1)?read_imagef(data0, smp, (int2)((gidx0+1),(gidx1+-1))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "((gidx0<1)!=True)", "(gidx0+1)", "(gidx1+-1)")
|
||||
|
||||
valid = (gidx1).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
load = get_load_image_uop(shape, valid, (gidx0, gidx1-1))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+-1)))")
|
||||
self.check(load, None, "gidx0", "(gidx1+-1)")
|
||||
|
||||
def test_idx_lt_bound(self):
|
||||
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
||||
self.check(load, None, "gidx0", "gidx1")
|
||||
|
||||
# same thing, valid has a div
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1//2).lt(5), (gidx0, gidx1))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
||||
self.check(load, None, "gidx0", "gidx1")
|
||||
|
||||
# 10x20 image: gidx1 < 10 is not the image bound here, so the valid is kept
|
||||
load = get_load_image_uop((20, 10, 4), (gidx1).lt(10), (gidx0, gidx1))
|
||||
self.assertEqual(render(load),
|
||||
"((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(gidx1<10)", "gidx0", "gidx1")
|
||||
|
||||
def test_generic_idx_lt_bound(self):
|
||||
# (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
|
||||
@@ -184,11 +128,10 @@ class TestImageSimplification(unittest.TestCase):
|
||||
gidx1 = Special("gidx1", 32)
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(8), (gidx0, gidx1+2))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+2)))")
|
||||
self.check(load, None, "gidx0", "(gidx1+2)")
|
||||
|
||||
load = get_load_image_uop(shape, (gidx1).lt(5), (gidx0, gidx1+5))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))")
|
||||
self.check(load, None, "gidx0", "(gidx1+5)")
|
||||
|
||||
def test_valid_empty_set(self):
|
||||
gidx0 = Special("gidx0", 32)
|
||||
@@ -197,12 +140,13 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = (gidx0%2, gidx1+2)
|
||||
# not empty
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8), idx)
|
||||
self.assertEqual(render(load),
|
||||
"((gidx0<8)?read_imagef(data0, smp, (int2)((gidx0%2),(gidx1+2))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(gidx0<8)", "(gidx0%2)", "(gidx1+2)")
|
||||
|
||||
# empty
|
||||
# empty -> invalid
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
|
||||
self.assertRaises(IndexError, lambda: render(load))
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
self.assertEqual(load.op, UOps.VECTORIZE)
|
||||
self.assertEqual(load.dtype.count, 4)
|
||||
|
||||
def test_openpilot_conv1(self):
|
||||
# first conv in openpilot
|
||||
@@ -224,8 +168,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((idx1*48)+(ridx2*6)+ridx0+-6),((idx2*2)+ridx1+-1)))")
|
||||
self.check(load, None, "((((idx1*48)+(ridx2*6))+ridx0)+-6)", "(((idx2*2)+ridx1)+-1)")
|
||||
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
@@ -246,8 +189,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((idx1*24)+(ridx2*3)+ridx0+-3),((idx2*2)+ridx1+-1)))")
|
||||
self.check(load, None, "((((idx1*24)+(ridx2*3))+ridx0)+-3)", "(((idx2*2)+ridx1)+-1)")
|
||||
|
||||
def test_openpilot_conv3(self):
|
||||
# in openpilot 0.9.7
|
||||
@@ -266,10 +208,11 @@ class TestImageSimplification(unittest.TestCase):
|
||||
idx = (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4)))
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
# TODO: simplify idx
|
||||
# alu0 = ((idx2*2)+ridx0)
|
||||
self.assertEqual(render(load),
|
||||
"(((alu0<11)&((((idx1*8)+ridx1)<3)!=1))?read_imagef(data0, smp, (int2)(((idx0+(idx1*512)+(ridx1*64)+832)%1024),(alu0+((idx1+((ridx1+5)//8)+1)//2)+-4))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
|
||||
self.check(load,
|
||||
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
|
||||
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
@@ -277,9 +220,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
valid = gidx.lt(488) & (gidx).lt(480).ne(True)
|
||||
idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
|
||||
load = get_load_image_uop((1, 26, 4), valid, idx)
|
||||
# alu0 is ((gidx*3)+18)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((gidx*3)+-1438),0))")
|
||||
self.check(load, None, "((gidx*3)+-1438)", "0")
|
||||
|
||||
def test_simplify2(self):
|
||||
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
|
||||
@@ -287,8 +228,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
valid = lidx.lt(3) & lidx.lt(1).ne(True)
|
||||
idx = ((lidx+1)%2, (lidx+1)//2-1)
|
||||
load = get_load_image_uop((1, 2, 4), valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)((lidx+-1),0))")
|
||||
self.check(load, None, "(lidx+-1)", "0")
|
||||
|
||||
def test_simplify3(self):
|
||||
# from openpilot
|
||||
@@ -296,8 +236,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
valid = idx0.lt(201).ne(True)
|
||||
idx = ((idx0+55)%64, (idx0+55)//64-4)
|
||||
load = get_load_image_uop((1, 64, 4), valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)((idx0+-201),0))")
|
||||
self.check(load, None, "(idx0+-201)", "0")
|
||||
|
||||
def test_simplify4(self):
|
||||
idx0 = Special("idx0", 512)
|
||||
@@ -311,24 +250,16 @@ class TestImageSimplification(unittest.TestCase):
|
||||
|
||||
# TODO: can this be simplified further?
|
||||
load = get_load_image_uop(shape, alu9, (((alu8+(alu2*8))%64),(alu2//8)))
|
||||
# alu0 = (((idx0*4)+1)%32)
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((alu0*8)+(idx0//32))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(idx0<256)", "((((((idx0*4)+1)%32)*8)+(idx0//32))%64)", "((((idx0*4)+1)%32)//8)")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, (((alu8+(alu3*8))%64),(alu3//8)))
|
||||
# alu0 = (((idx0*4)+2)%32)
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((alu0*8)+(idx0//32))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(idx0<256)", "((((((idx0*4)+2)%32)*8)+(idx0//32))%64)", "((((idx0*4)+2)%32)//8)")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, (((alu8+(alu4*8))%64),(alu4//8)))
|
||||
# alu0 = (((idx0*4)+3)%32)
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((alu0*8)+(idx0//32))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(idx0<256)", "((((((idx0*4)+3)%32)*8)+(idx0//32))%64)", "((((idx0*4)+3)%32)//8)")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, (((alu8+(alu5*8))%64),(alu5//8)))
|
||||
# alu0 = ((idx0*4)%32)
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((alu0*8)+(idx0//32))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(idx0<256)", "(((((idx0*4)%32)*8)+(idx0//32))%64)", "(((idx0*4)%32)//8)")
|
||||
|
||||
def test_simplify5(self):
|
||||
# openpilot 0.9.7, chunk replacement to simplify
|
||||
@@ -343,8 +274,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
valid = alu3.lt(640)
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
self.check(load, "(((((idx0*4)+(idx1*256))+1)%768)<640)", "((idx0+((idx1//3)*16))+128)", "(((((idx0*4)+(idx1*256))+1)%768)//64)")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user