Add RDNA3 assembler UOps.CAST partial support + other fixes/improvements (#1012)

* Add support for one case of `UOps.CAST` for RDNA3 assembler

 * Adds support for casting from `bool` -> `float32`.  Seems like a very common operation that is required in many places.
 * Fix bool register definition for vector operations
   * Use `vcc_lo` instead of `vcc` which seems to be required since it's configured to use wavefront_size=32
 * Add vector support for some places that were scalar only in register definition and comparison ops
 * Fix some issues in what seems to be defunct `external_test_image.py`
   * Some tests still don't pass for other reasons, but it at least runs now and one broken test is now fixed

* Refactor RDNA3 assembler register definition

 * Unify multi-register code between dtypes and combine it with single-register allocation, since they're all untyped registers at the end of the day
This commit is contained in:
Casey Primozic
2023-06-20 11:34:10 -07:00
committed by GitHub
parent 57d3aa76a5
commit aab9ee0fca
2 changed files with 28 additions and 30 deletions

View File

@@ -7,7 +7,6 @@ if 'IMAGE' not in os.environ:
os.environ['GPU'] = '1'
os.environ['OPT'] = '2'
from tinygrad.tensor import Tensor
from tinygrad.runtime.ops_gpu import CLImage
from tinygrad.nn import Conv2d
Tensor.no_grad = True
@@ -16,16 +15,14 @@ class TestImage(unittest.TestCase):
t = Tensor.ones(128, 128, 1)
t = t.reshape(128, 32, 4) + 3
t.realize()
assert isinstance(t.lazydata.realized._buf, CLImage)
np.testing.assert_array_equal(t.numpy(), np.ones((128,32,4))*4)
def test_sum_image(self):
t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3
t1.realize()
assert isinstance(t1.lazydata.realized._buf, CLImage)
t1 = t1.sum()
t1.realize()
assert t1.numpy()[0] == 16*4*4*4, f"got {t1.numpy()}"
assert t1.numpy() == 16*4*4*4, f"got {t1.numpy()}"
def test_add_image(self):
t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3
@@ -34,7 +31,6 @@ class TestImage(unittest.TestCase):
t2.realize()
t3 = t1 + t2
t3.realize()
assert isinstance(t3.lazydata.realized._buf, CLImage)
np.testing.assert_array_equal(t3.numpy(), np.ones((16,4,4))*9)
def test_padded_conv(self):

View File

@@ -62,31 +62,30 @@ class RDNACodegen(AssemblyCodegen):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
if arg[0][0] == dtypes.uint64 and arg[0][1]:
# assuming these are scalar
s_cnt += s_cnt%2 # aligned(2)
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s[{s_cnt}:{s_cnt+1}]"
s_cnt += 2
elif arg[0][0] == dtypes._float4 and not arg[0][1]:
v_cnt += (4-v_cnt%4) if v_cnt%4 != 0 else 0
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v[{v_cnt}:{v_cnt+3}]"
for off in range(4): rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"v{v_cnt+off}"
v_cnt += 4
elif arg[0][0] in [dtypes.int32, dtypes.float32]:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float64, dtypes._float4]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
align = int(arg[0][0].itemsize / 4)
if arg[0][1]:
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s{s_cnt}"
s_cnt += 1
s_cnt += s_cnt % align
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
s_cnt += align
else:
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v{v_cnt}"
v_cnt += 1
elif arg[0][0] == dtypes.bool and arg[0][1]:
v_cnt += v_cnt % align
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
v_cnt += align
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
if arg[0][0] == dtypes._float4:
for off in range(4):
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
elif arg[0][0] == dtypes.bool:
for i in range(arg[2]):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if arg[0][1] else "vcc"
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
else:
raise NotImplementedError(arg)
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
elif uop == UOps.SPECIAL:
if arg.startswith('buf'):
i = int(arg[3:])
@@ -116,11 +115,8 @@ class RDNACodegen(AssemblyCodegen):
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
elif uop == UOps.ALU:
if arg == BinaryOps.CMPLT:
if out.scalar:
ins.append(f"s_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
ins.append(f"v_cmp_lt_{dtype_to_rdnatype[out.dtype]} vcc, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
if arg in [BinaryOps.CMPLT, BinaryOps.CMPEQ]:
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
alu_arg = alu[arg]
if arg == FusedOps.MULACC and out == vin[2]:
@@ -145,6 +141,12 @@ class RDNACodegen(AssemblyCodegen):
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
elif uop == UOps.CAST:
if vin[0].dtype == dtypes.bool:
if out.dtype == dtypes.float32:
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
else:
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
else:
raise NotImplementedError(uop)