mirror of https://github.com/tinygrad/tinygrad.git
refactor UOps.CONST (#4639)
* delete more
* nit: dont need assign
* can this be simpler
* use scalars
* always cast
* clang needs cast
* format
test/external/fuzz_schedule.py

@@ -11,7 +11,7 @@ from tinygrad.tensor import Tensor
 
 ctx_vars = { MULTIOUTPUT: (0, 1) }
 
-def fuzz_schedule(outs: List[LazyBuffer]):
+def fuzz_schedule(outs:List[LazyBuffer]):
   # find toposorts across all tunable params
   unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {}
   for combination in itertools.product(*ctx_vars.values()):
@@ -66,7 +66,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
       print(f"FAILED FOR {out}")
       raise e
 
-def _exec_si(si: ScheduleItem, seed:int):
+def _exec_si(si:ScheduleItem, seed:int):
   ei = lower_schedule_item(si)
   if len(capturing): capturing[0].add(ei)
   if isinstance(ei.prg, CustomOp): Tensor._seed = seed
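The hunk above only tightens annotation spacing, but it also shows fuzz_schedule's core pattern: enumerate every combination of the tunable context variables with itertools.product and exercise the scheduler under each setting. A minimal runnable sketch of just that enumeration, with plain strings standing in for tinygrad's ContextVar objects:

import itertools

# each tunable knob maps to the values to try; MULTIOUTPUT is toggled between 0 and 1
ctx_vars = {"MULTIOUTPUT": (0, 1)}

# itertools.product yields one tuple per combination of knob values
for combination in itertools.product(*ctx_vars.values()):
  settings = dict(zip(ctx_vars.keys(), combination))
  print(settings)  # {'MULTIOUTPUT': 0}, then {'MULTIOUTPUT': 1}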
tinygrad/codegen/linearizer.py

@@ -52,16 +52,13 @@ class Linearizer(Kernel):
   # NOTE: the consts have to be cached for deduping of downstream uops to work
   def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
-    if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0])
-    else: return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
-
-  def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
+    return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
 
   def get_reduce_acc(self, reduceop:LazyOp):
-    dtype = reduceop.dtype
-    if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
+    if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
     elif reduceop.op is ReduceOps.MAX:
-      if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
-      return -math.inf if dtypes.is_float(dtype) else False
+      if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
+      return -math.inf if dtypes.is_float(reduceop.dtype) else False
 
   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt
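Two things change in this hunk: const now unconditionally emits a UOps.CONST whose arg is pre-normalized by dtypes.as_const (the Variable/DEFINE_VAR special case and the cast helper are dropped), and get_reduce_acc loses its local dtype alias. The underlying logic stays the same: normalize the scalar so equal consts dedupe, and seed each reduction with its identity element. A rough standalone illustration of both ideas, using hypothetical names rather than tinygrad's API:

import math

def as_const_sketch(val, is_float:bool, is_bool:bool=False):
  # coerce a Python scalar into the dtype's value space so that two
  # CONST nodes carrying the "same" value compare equal and dedupe
  if is_bool: return bool(val)
  return float(val) if is_float else int(val)

def reduce_acc_sketch(op:str, is_float:bool, unsigned:bool=False, itemsize:int=4):
  # identity elements: SUM accumulates from zero, MAX from the dtype's minimum
  if op == "SUM": return 0.0 if is_float else 0
  if op == "MAX": return -math.inf if is_float else (0 if unsigned else -2**(itemsize*8-1))

print(as_const_sketch(3, is_float=True))         # 3.0
print(reduce_acc_sketch("MAX", is_float=False))  # -2147483648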
tinygrad/renderer/cstyle.py

@@ -43,9 +43,9 @@ class CStyleLanguage(Renderer):
   def render_const(self, x:ConstType, dtype:DType) -> str:
     if math.isnan(x): val = "NAN"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-    elif dtype == dtypes.float64: val = f"{x}"
-    elif dtype == dtypes.bool: val = "1" if x else "0"
-    else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}"
+    elif dtype is dtypes.bool: val = "1" if x else "0"
+    elif dtype is dtypes.float: val = f"{x}f"
+    else: val = str(x)
     return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
 
   # returns a str expression of the loaded value with the output type
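After this change the scalar-to-C-literal mapping is explicit: bool renders as 1/0, float32 gets an f suffix, and everything else (including float64) falls through to str(x); any vector count > 1 or any dtype outside float/int/bool then gets wrapped by render_cast, matching the "clang needs cast" note in the commit message. A rough standalone sketch of that rendering, with a simple C-style cast standing in for tinygrad's render_cast:

import math

def render_const_sketch(x, dtype:str, count:int=1) -> str:
  if math.isnan(x): val = "NAN"
  elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
  elif dtype == "bool": val = "1" if x else "0"
  elif dtype == "float": val = f"{x}f"
  else: val = str(x)  # ints, doubles, etc. render as plain text
  # non-scalar or non-native dtypes get an explicit cast, e.g. (float4)(...)
  if count > 1 or dtype not in ("float", "int", "bool"):
    return f"({dtype}{count if count > 1 else ''})({val})"
  return val

print(render_const_sketch(1.5, "float"))   # 1.5f
print(render_const_sketch(True, "bool"))   # 1
print(render_const_sketch(2.0, "double"))  # (double)(2.0)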