unwrap instead of cast [pr] (#12982)

This commit is contained in:
chenyu
2025-10-28 21:29:23 -04:00
committed by GitHub
parent f55fcfecf9
commit ef16e6c68c
3 changed files with 7 additions and 9 deletions

View File

@@ -1,5 +1,4 @@
import unittest
from typing import List, cast
import numpy as np
from tinygrad.device import Buffer, Device, is_dtype_supported
from tinygrad.dtype import dtypes, ConstType
@@ -15,15 +14,15 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.codegen import full_rewrite
from tinygrad.engine.realize import lower_schedule_item
def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
for x in inputs: x.realize()
# NOTE: we only toposort the stores
uops: List[UOp] = []
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
uops: list[UOp] = []
def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
inbufs = [x.uop.base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render(uops)
ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))