mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Add RDNA3 assembler UOps.CAST partial support + other fixes/improvements (#1012)
* Add support for one case of `UOps.CAST` for RDNA3 assembler * Adds support for casting from `bool` -> `float32`. Seems like a very common operation that is required in many places. * Fix bool register definition for vector operations * Use `vcc_lo` instead of `vcc` which seems to be required since it's configured to use wavefront_size=32 * Add vector support for some places that were scalar only in register definition and comparison ops * Fix some issues in what seems to be defunct `external_test_image.py` * Some tests still don't pass for other reasons, but it at least runs now and one broken test is now fixed * Refactor RDNA3 assembler register definition * Unify multi-registor code between dtypes and combine with single-register allocation since they're all untyped registers at the end of the day
This commit is contained in:
6
test/external/external_test_image.py
vendored
6
test/external/external_test_image.py
vendored
@@ -7,7 +7,6 @@ if 'IMAGE' not in os.environ:
|
||||
os.environ['GPU'] = '1'
|
||||
os.environ['OPT'] = '2'
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.runtime.ops_gpu import CLImage
|
||||
from tinygrad.nn import Conv2d
|
||||
Tensor.no_grad = True
|
||||
|
||||
@@ -16,16 +15,14 @@ class TestImage(unittest.TestCase):
|
||||
t = Tensor.ones(128, 128, 1)
|
||||
t = t.reshape(128, 32, 4) + 3
|
||||
t.realize()
|
||||
assert isinstance(t.lazydata.realized._buf, CLImage)
|
||||
np.testing.assert_array_equal(t.numpy(), np.ones((128,32,4))*4)
|
||||
|
||||
def test_sum_image(self):
|
||||
t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3
|
||||
t1.realize()
|
||||
assert isinstance(t1.lazydata.realized._buf, CLImage)
|
||||
t1 = t1.sum()
|
||||
t1.realize()
|
||||
assert t1.numpy()[0] == 16*4*4*4, f"got {t1.numpy()}"
|
||||
assert t1.numpy() == 16*4*4*4, f"got {t1.numpy()}"
|
||||
|
||||
def test_add_image(self):
|
||||
t1 = Tensor.ones(16, 16, 1).reshape(16, 4, 4) + 3
|
||||
@@ -34,7 +31,6 @@ class TestImage(unittest.TestCase):
|
||||
t2.realize()
|
||||
t3 = t1 + t2
|
||||
t3.realize()
|
||||
assert isinstance(t3.lazydata.realized._buf, CLImage)
|
||||
np.testing.assert_array_equal(t3.numpy(), np.ones((16,4,4))*9)
|
||||
|
||||
def test_padded_conv(self):
|
||||
|
||||
@@ -62,31 +62,30 @@ class RDNACodegen(AssemblyCodegen):
|
||||
return rtor[x]
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] == dtypes.uint64 and arg[0][1]:
|
||||
# assuming these are scalar
|
||||
s_cnt += s_cnt%2 # aligned(2)
|
||||
for i in range(arg[2]):
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s[{s_cnt}:{s_cnt+1}]"
|
||||
s_cnt += 2
|
||||
elif arg[0][0] == dtypes._float4 and not arg[0][1]:
|
||||
v_cnt += (4-v_cnt%4) if v_cnt%4 != 0 else 0
|
||||
for i in range(arg[2]):
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v[{v_cnt}:{v_cnt+3}]"
|
||||
for off in range(4): rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = f"v{v_cnt+off}"
|
||||
v_cnt += 4
|
||||
elif arg[0][0] in [dtypes.int32, dtypes.float32]:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float64, dtypes._float4]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
if arg[0][1]:
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"s{s_cnt}"
|
||||
s_cnt += 1
|
||||
s_cnt += s_cnt % align
|
||||
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
|
||||
s_cnt += align
|
||||
else:
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = f"v{v_cnt}"
|
||||
v_cnt += 1
|
||||
elif arg[0][0] == dtypes.bool and arg[0][1]:
|
||||
v_cnt += v_cnt % align
|
||||
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
|
||||
v_cnt += align
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
|
||||
if arg[0][0] == dtypes._float4:
|
||||
for off in range(4):
|
||||
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
||||
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
||||
elif arg[0][0] == dtypes.bool:
|
||||
for i in range(arg[2]):
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = "scc" if arg[0][1] else "vcc"
|
||||
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
else:
|
||||
raise NotImplementedError(arg)
|
||||
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('buf'):
|
||||
i = int(arg[3:])
|
||||
@@ -116,11 +115,8 @@ class RDNACodegen(AssemblyCodegen):
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
||||
elif uop == UOps.ALU:
|
||||
if arg == BinaryOps.CMPLT:
|
||||
if out.scalar:
|
||||
ins.append(f"s_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
else:
|
||||
ins.append(f"v_cmp_lt_{dtype_to_rdnatype[out.dtype]} vcc, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
if arg in [BinaryOps.CMPLT, BinaryOps.CMPEQ]:
|
||||
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
else:
|
||||
alu_arg = alu[arg]
|
||||
if arg == FusedOps.MULACC and out == vin[2]:
|
||||
@@ -145,6 +141,12 @@ class RDNACodegen(AssemblyCodegen):
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
|
||||
elif uop == UOps.CAST:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
if out.dtype == dtypes.float32:
|
||||
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
||||
else:
|
||||
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
||||
else:
|
||||
raise NotImplementedError(uop)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user