Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
remove SQRT hack in llvm (#9067)
replaced with xpow 0.5 in transcendental. fixed sqrt(0) backward
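In other words, instead of the LLVM backend emitting its own sqrt intrinsic, SQRT is now lowered like the other transcendental ops: x.sqrt() becomes xpow(x, 0.5), which presumably bottoms out in the exp2/log2 helpers already used in the transcendental module. Below is a minimal standalone sketch of the identity this relies on, in plain NumPy rather than tinygrad code; the real xpow also has to handle edge cases such as 0, negative inputs, and infinities.

import numpy as np

def sqrt_via_pow(x: np.ndarray) -> np.ndarray:
  # sqrt(x) == x**0.5 == 2**(0.5 * log2(x)) for x > 0
  return np.exp2(0.5 * np.log2(x))

x = np.array([1e-6, 0.25, 1.0, 2.0, 1e6])
print(np.allclose(sqrt_via_pow(x), np.sqrt(x)))  # True (to float tolerance)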
@@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase):
     one = Tensor(1, dtype=dtypes.int64)
 
     # non-scalar indexed with scalars
-    a = Tensor.randn(2, 3)
+    a = Tensor.randn(2, 3).realize()
     numpy_testing_assert_equal_helper(a[0], a[zero])
     numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
     numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
@@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase):
     numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
 
     # scalar indexed with scalar
-    r = Tensor.randn()
+    r = Tensor.randn().realize()
     with self.assertRaises(IndexError):
       r[:]
     with self.assertRaises(IndexError):
@@ -563,8 +563,8 @@ class TestNN(unittest.TestCase):
     layer.weight.shard_(devices, 3)
     layer.bias.shard_(devices, None)
     state_dict = {
-      'weight': Tensor.randn(5, 3, 3, 3),
-      'bias': Tensor.randn(5),
+      'weight': Tensor.randn(5, 3, 3, 3).realize(),
+      'bias': Tensor.randn(5).realize(),
     }
     load_state_dict(layer, state_dict)
 
@@ -641,9 +641,7 @@ class TestOps(unittest.TestCase):
 
   def test_sqrt(self):
     helper_test_op([(45,65)], lambda x: x.sqrt())
-    if Device.DEFAULT not in ("LLVM", "DSP"):
-      # TODO: fix backward
-      helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
+    helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
     helper_test_op([()], lambda x: x.sqrt())
   def test_rsqrt(self):
     helper_test_op([(45,65)], lambda x: x.rsqrt())
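The re-enabled vals=[[0.0]] case is what "fixed sqrt(0) backward" in the commit message refers to: helper_test_op checks both the forward value and the gradient against PyTorch, and the analytic gradient d/dx sqrt(x) = 1/(2*sqrt(x)) diverges at x = 0, so the reference gradient is inf. The diff does not show the gradient-side change itself; the snippet below is only a sketch of the PyTorch reference behavior the test now exercises on all backends, LLVM and DSP included.

import torch

# Reference behavior for sqrt(0) and its backward: grad = 1/(2*sqrt(0)) = +inf
x = torch.tensor([0.0], requires_grad=True)
y = x.sqrt()
y.sum().backward()
print(y.detach())  # tensor([0.])
print(x.grad)      # tensor([inf])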
@@ -1406,7 +1404,7 @@ class TestOps(unittest.TestCase):
   def test_asinh(self):
     helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
     # NOTE: this one has larger atol
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
     helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
   def test_acosh(self):
     helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)
@@ -129,6 +129,8 @@ powers_of_two = {2**i:i for i in range(64)}
 def get_late_rewrite_patterns(ops, force_transcendental=False):
   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  # rewrite SQRT to xpow 0.5
+  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
   # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
   if Ops.AND in ops:
     pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
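The two added lines are the generic replacement for the LLVM-specific hack: when a backend does not list Ops.SQRT among its supported ops, the late rewrite matches every SQRT uop, binds its source to "d", and substitutes the graph built by xpow(d, d.const_like(0.5)). A toy illustration of that match-and-replace idea follows, using stand-in classes rather than tinygrad's real UPat/UOp machinery, and expanding pow(d, 0.5) as EXP2(LOG2(d) * 0.5) for the sketch.

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()
  arg: float | None = None

def rewrite_sqrt(n: Node) -> Node:
  """Stand-in for the rewrite rule: SQRT(d) -> pow(d, 0.5) built from EXP2/MUL/LOG2."""
  src = tuple(rewrite_sqrt(s) for s in n.src)
  if n.op != "SQRT": return Node(n.op, src, n.arg)
  d = src[0]
  return Node("EXP2", (Node("MUL", (Node("LOG2", (d,)), Node("CONST", arg=0.5))),))

print(rewrite_sqrt(Node("SQRT", (Node("LOAD"),))))  # prints the rewritten EXP2/MUL/LOG2 tree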
@@ -77,8 +77,6 @@ llvm_rewrite = PatternMatcher([
     f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),
 
   # unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-    f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
   (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
   (UPat(GroupOp.Binary, name="x"), lambda ctx,x:
@@ -152,9 +150,6 @@ class LLVMRenderer(Renderer):
         f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]
 
     for u in uops:
-      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
-      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
-
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
         # NOTE: MallocAllocator promises 0x20 alignment