mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
UOps.BITCAST (#3747)
* UOps.BITCAST implicitly fixed no const folding for bitcast * python backend * ptx * consistent llvm
This commit is contained in:
@@ -137,14 +137,14 @@ class TestBFloat16(unittest.TestCase):
|
||||
assert tnp.dtype == np.float32
|
||||
np.testing.assert_allclose(tnp, np.array(data))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
|
||||
def test_bf16_ones(self):
|
||||
# TODO: fix this with correct bfloat16 cast
|
||||
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
|
||||
assert t.dtype == dtypes.bfloat16
|
||||
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
|
||||
def test_bf16_eye(self):
|
||||
# TODO: fix this with correct bfloat16 cast
|
||||
t = Tensor.eye(3, dtype=dtypes.bfloat16)
|
||||
|
||||
@@ -186,15 +186,13 @@ class TestConstantFolding(unittest.TestCase):
|
||||
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
|
||||
assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_bitcast_const(self):
|
||||
# TODO: fix bitcast const should not fold
|
||||
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
|
||||
si = create_schedule([t.lazydata])
|
||||
assert len(si) == 1
|
||||
si = si[0]
|
||||
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
|
||||
assert any(uop.uop is UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
|
||||
assert any(uop.uop is UOps.BITCAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -410,8 +410,8 @@ class Linearizer(Kernel):
|
||||
if cache is None: cache = {}
|
||||
if x in cache: return cache[x]
|
||||
if x.op in BufferOps: return loaded_buffers[x.arg]
|
||||
if x.op == UnaryOps.CAST:
|
||||
return [self.uops.add(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
|
||||
if x.op == UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
|
||||
for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
|
||||
if x.op in ReduceOps and not do_reduce:
|
||||
assert offs is None, "not available if we aren't doing reduce"
|
||||
return acc
|
||||
|
||||
@@ -14,7 +14,7 @@ class UOps(Enum):
|
||||
LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
@dataclass(eq=False)
|
||||
class UOp:
|
||||
|
||||
@@ -149,9 +149,9 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
elif uop == UOps.PHI:
|
||||
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop == UOps.CAST:
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
|
||||
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
|
||||
@@ -145,8 +145,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop is UOps.CAST:
|
||||
if isinstance(args, tuple) and args[1]: # bitcast
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa(None,'precast')
|
||||
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
|
||||
@@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop is UOps.ALU:
|
||||
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
|
||||
elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
|
||||
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
|
||||
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
|
||||
|
||||
@@ -104,13 +104,13 @@ class PythonProgram:
|
||||
del ul[i]
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop is UOps.CAST:
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if dtype.count > 1:
|
||||
ul[i] = inp
|
||||
else:
|
||||
assert dtp[0].fmt and dtype.fmt
|
||||
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
|
||||
if isinstance(arg, tuple) and arg[1]:
|
||||
if uop is UOps.BITCAST:
|
||||
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
|
||||
else:
|
||||
casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]
|
||||
|
||||
Reference in New Issue
Block a user