fix bufferize cost function for multi, improve VIZ=-1 cli (#14394)

* improve cli

* remove_bufferize change
This commit is contained in:
qazal
2026-01-28 01:53:18 -05:00
committed by GitHub
parent c158acea29
commit 0294014108
3 changed files with 5 additions and 2 deletions

View File

@@ -10,7 +10,11 @@ def optional_eq(val:dict, arg:str|None) -> bool: return arg is None or ansistrip
def print_data(data:dict) -> None:
if isinstance(data.get("value"), Iterator):
for m in data["value"]:
if m.get("uop"):
print("Input UOp:")
print(m["uop"])
if not m["diff"]: continue
print("Rewrites:")
fp = pathlib.Path(m["upat"][0][0])
print(f"{fp.parent.name}/{fp.name}:{m['upat'][0][1]}")
print(m["upat"][1])

View File

@@ -1263,7 +1263,6 @@ class TestMultiRamUsage(unittest.TestCase):
self.assertEqual(total_mem[dtypes.half], total_mem[dtypes.float] // 2)
def test_matmul_half(self): self._test_matmul_half(devices_2)
-@unittest.expectedFailure
def test_matmul_half_alt(self): self._test_matmul_half(devices_4)
@unittest.skipIf(not_support_multi_device(), "need multi")

View File

@@ -171,7 +171,7 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
indexes: list[UOp] = []
reduces: list[UOp] = []
def red_gate(x:UOp):
-if x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL:
+if (x.op is Ops.BUFFERIZE and x.arg.addrspace == AddrSpace.GLOBAL) or x.op is Ops.MSTACK:
accessed_buffers.append(x)
return False
if x.op is Ops.BUFFER: