refactor UOps.CONST (#4639)

* delete more

* nit: dont need assign

* can this be simpler

* use scalars

* always cast

* clang needs cast

* format
This commit is contained in:
qazal
2024-05-18 15:07:36 +08:00
committed by GitHub
parent d0a2d40df3
commit a5204fe89d
3 changed files with 9 additions and 12 deletions

View File

@@ -11,7 +11,7 @@ from tinygrad.tensor import Tensor
ctx_vars = { MULTIOUTPUT: (0, 1) }
def fuzz_schedule(outs: List[LazyBuffer]):
def fuzz_schedule(outs:List[LazyBuffer]):
# find toposorts across all tunable params
unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, _LBScheduleItem]]] = {}
for combination in itertools.product(*ctx_vars.values()):
@@ -66,7 +66,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
print(f"FAILED FOR {out}")
raise e
def _exec_si(si: ScheduleItem, seed:int):
def _exec_si(si:ScheduleItem, seed:int):
ei = lower_schedule_item(si)
if len(capturing): capturing[0].add(ei)
if isinstance(ei.prg, CustomOp): Tensor._seed = seed

View File

@@ -52,16 +52,13 @@ class Linearizer(Kernel):
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
if isinstance(b, Variable): return self.uops.add(UOps.DEFINE_VAR, dtype, tuple(), b.unbind()[0])
else: return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
return self.uops.add(UOps.CONST, dtype, tuple(), dtypes.as_const(b, dtype))
def get_reduce_acc(self, reduceop:LazyOp):
dtype = reduceop.dtype
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
elif reduceop.op is ReduceOps.MAX:
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
return -math.inf if dtypes.is_float(dtype) else False
if dtypes.is_int(reduceop.dtype): return 0 if dtypes.is_unsigned(reduceop.dtype) else -2**(reduceop.dtype.itemsize*8-1)
return -math.inf if dtypes.is_float(reduceop.dtype) else False
# NOTE: once images are loaded, we uop them as their base float
def get_base_dtype(self, dt:DType): return dt.base if isinstance(dt, ImageDType) else dt

View File

@@ -43,9 +43,9 @@ class CStyleLanguage(Renderer):
def render_const(self, x:ConstType, dtype:DType) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
elif dtype == dtypes.float64: val = f"{x}"
elif dtype == dtypes.bool: val = "1" if x else "0"
else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}"
elif dtype is dtypes.bool: val = "1" if x else "0"
elif dtype is dtypes.float: val = f"{x}f"
else: val = str(x)
return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type