UOps.BITCAST (#3747)

* UOps.BITCAST

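This implicitly fixes the constant-folding issue for bitcast: a bitcast of a constant is no longer folded away, so test_bitcast_const is no longer an expected failure.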

* python backend

* ptx

* make the llvm backend consistent with the other backends
This commit is contained in:
chenyu
2024-03-14 21:00:35 -04:00
committed by GitHub
parent 9a00a453c7
commit 75d4344cda
8 changed files with 13 additions and 15 deletions

View File

@@ -137,14 +137,14 @@ class TestBFloat16(unittest.TestCase):
assert tnp.dtype == np.float32
np.testing.assert_allclose(tnp, np.array(data))
@unittest.expectedFailure
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
def test_bf16_ones(self):
# TODO: fix this with correct bfloat16 cast
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
@unittest.expectedFailure
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
def test_bf16_eye(self):
# TODO: fix this with correct bfloat16 cast
t = Tensor.eye(3, dtype=dtypes.bfloat16)

View File

@@ -186,15 +186,13 @@ class TestConstantFolding(unittest.TestCase):
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
@unittest.expectedFailure
def test_bitcast_const(self):
# TODO: fix bitcast const should not fold
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
si = create_schedule([t.lazydata])
assert len(si) == 1
si = si[0]
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert any(uop.uop is UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
assert any(uop.uop is UOps.BITCAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -410,8 +410,8 @@ class Linearizer(Kernel):
if cache is None: cache = {}
if x in cache: return cache[x]
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op == UnaryOps.CAST:
return [self.uops.add(UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
if x.op == UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc

View File

@@ -14,7 +14,7 @@ class UOps(Enum):
LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); BITCAST = auto(); GEP = auto() # noqa: E702
@dataclass(eq=False)
class UOp:

View File

@@ -149,9 +149,9 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
elif uop == UOps.PHI:
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop == UOps.CAST:
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop == UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"

View File

@@ -145,8 +145,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop is UOps.CAST:
if isinstance(args, tuple) and args[1]: # bitcast
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa(None,'precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")

View File

@@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)

View File

@@ -104,13 +104,13 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
elif uop is UOps.CAST:
elif uop in {UOps.CAST, UOps.BITCAST}:
if dtype.count > 1:
ul[i] = inp
else:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if isinstance(arg, tuple) and arg[1]:
if uop is UOps.BITCAST:
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [float(x) if dtypes.is_float(dtype) else int(x) if dtypes.is_int(dtype) else x for x in inp[0]]