mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
unwrap instead of cast [pr] (#12982)
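The change below drops a redundant typing.cast: x.uop is already typed as a UOp, so the cast only quieted the type checker and did nothing at runtime. The commit title suggests the rest of the PR prefers an unwrap-style helper over cast where a real check is wanted. As a minimal standalone sketch (the unwrap below is illustrative and only assumed to behave like tinygrad.helpers.unwrap), the practical difference is that cast never checks anything, while an unwrap-style helper asserts the value is actually present:

    from typing import Optional, TypeVar, cast

    T = TypeVar("T")

    def unwrap(x: Optional[T]) -> T:
      # illustrative unwrap-style helper (assumed to mirror tinygrad.helpers.unwrap):
      # unlike typing.cast, it verifies at runtime that the value is present
      assert x is not None, "unwrap of None"
      return x

    maybe: Optional[int] = None
    y = cast(int, maybe)  # the type checker now sees int, but y is still None at runtime
    # unwrap(maybe)       # would raise AssertionError instead of letting None leak through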
@@ -1,5 +1,4 @@
 import unittest
-from typing import List, cast
 import numpy as np
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
@@ -15,15 +14,15 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.codegen import full_rewrite
 from tinygrad.engine.realize import lower_schedule_item
 
-def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
   for x in inputs: x.realize()
   # NOTE: we only toposort the stores
-  uops: List[UOp] = []
-  def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
+  uops: list[UOp] = []
+  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
      initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
-  inbufs = [cast(UOp,x.uop).base.buffer for x in inputs]
+  inbufs = [x.uop.base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render(uops)
   ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
                                   src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
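The same hunks also replace typing.List with the built-in list generic (PEP 585, Python 3.9+), which is why the typing import can disappear entirely once cast is gone. A standalone sketch of the annotation change (dedup_ints is a made-up example, not tinygrad's dedup):

    # old spelling, which required `from typing import List`:
    #   def dedup_ints(xs: List[int]) -> List[int]: ...

    # new spelling on Python 3.9+: the built-in list type is subscriptable directly
    def dedup_ints(xs: list[int]) -> list[int]:
      # keep the first occurrence of each value, preserving order (illustrative only)
      return list(dict.fromkeys(xs))

    assert dedup_ints([3, 1, 3, 2]) == [3, 1, 2]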