UOps.BITCAST (#3747)

* UOps.BITCAST

implicitly fixed no const folding for bitcast

* python backend

* ptx

* consistent llvm
This commit is contained in:
chenyu
2024-03-14 21:00:35 -04:00
committed by GitHub
parent 9a00a453c7
commit 75d4344cda
8 changed files with 13 additions and 15 deletions

View File

@@ -149,9 +149,9 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
elif uop == UOps.PHI:
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop == UOps.CAST:
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop == UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"

View File

@@ -145,8 +145,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop is UOps.CAST:
if isinstance(args, tuple) and args[1]: # bitcast
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa(None,'precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")

View File

@@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)
elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1])
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)