mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fix compile error for max with inline const (#5840)
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest
|
||||
from typing import List, cast
|
||||
import numpy as np
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.device import Buffer, CompileError, Device
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.dtype import PtrDType, DType, dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import dedup, flatten
|
||||
@@ -36,12 +36,8 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)))
|
||||
store = UOp.store(a, idx, alu)
|
||||
# CLANG doesn't use the max function
|
||||
if Device.DEFAULT in ["CLANG"]:
|
||||
ret = _test_uop_result([Tensor([1])], [store])[0]
|
||||
assert ret[0] == 1
|
||||
# call to 'max' is ambiguous
|
||||
else:
|
||||
with self.assertRaises(CompileError): _test_uop_result([Tensor([1])], [store])
|
||||
ret = _test_uop_result([Tensor([1])], [store])[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -137,6 +137,7 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
|
||||
elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], cast(DType, v.dtype)) if v.op is UOps.CONST else r[v] for v in src]
|
||||
else: operands = [r[v] for v in src]
|
||||
val = self.code_for_op[args](*operands, dtype)
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
|
||||
Reference in New Issue
Block a user