mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
tk softmax (#13205)
test/external/external_test_tk.py (vendored) | 64
@@ -1,4 +1,4 @@
-import unittest
+import unittest, math
 from tinygrad import Tensor, Device, dtypes, Context
 from tinygrad.engine.realize import ExecItem, get_runner
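The new math import feeds the base-2 exponential trick used below: inputs are pre-scaled by 1/ln 2 so the kernel can use exp2 (which maps to a fast hardware instruction on most GPUs) in place of the natural exp, since e^x = 2^(x/ln 2). A quick sanity check of the identity in plain Python (illustrative, not part of the diff):

    import math
    x = 1.7
    # e**x equals 2**(x / ln 2); the kernel folds the 1/ln 2 factor into a map
    assert abs(math.exp(x) - 2.0 ** (x * (1.0 / math.log(2)))) < 1e-12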
@@ -335,5 +335,67 @@ class TestTK(unittest.TestCase):
     np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
 
+  def test_softmax(self):
+    N = 32
+    BLOCK_SIZE = 16
+    with Kernel((1, 1, 1), WARP_THREADS) as ker:
+      warp = ker.warp
+
+      b = gl((1, 1, BLOCK_SIZE, N), dtypes.float32)
+      a = gl((1, 1, BLOCK_SIZE, N), dtypes.float32)
+
+      a_smem = st((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      a_reg = rt((BLOCK_SIZE, BLOCK_SIZE), dtypes.float32)
+
+      max_vec_last = rv(BLOCK_SIZE, dtypes.float32, "ortho")
+      max_vec = rv(BLOCK_SIZE, dtypes.float32, "ortho")
+      norm_vec = rv(BLOCK_SIZE, dtypes.float32, "ortho")
+
+      max_vec = warp.neg_inf(max_vec)
+      norm_vec = warp.zero(norm_vec)
+
+      for tile_col in ker.range(N // BLOCK_SIZE):
+        a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
+        a_reg = warp.load(a_reg, a_smem)
+
+        a_reg = warp.map(a_reg, lambda x: x * (1.0 / math.log(2)))
+
+        max_vec_last = warp.copy(max_vec_last.after(tile_col), max_vec)
+        max_vec = warp.row_reduce(max_vec, a_reg, lambda a, b: a.maximum(b))
+        a_reg = warp.map(a_reg, lambda x, idx: (x - max_vec[idx[0], 0, (idx[2]%4)//2]).exp2())
+        max_vec_last = warp.map(max_vec_last, lambda x, idx: (x - max_vec[*idx]).exp2())
+        norm_vec = warp.map(norm_vec, lambda x, idx: x * max_vec_last[*idx])
+        norm_vec = warp.row_reduce(norm_vec, a_reg, lambda a, b: a + b)
+      norm_vec = ker.endrange()
+
+      for tile_col in ker.range(N // BLOCK_SIZE):
+        a_smem = warp.load(a_smem, a, (), (0, 0, 0, tile_col), axis=2)
+        a_reg = warp.load(a_reg, a_smem)
+
+        a_reg = warp.map(a_reg, lambda x: x * (1.0 / math.log(2)))
+        a_reg = warp.map(a_reg, lambda x, idx: (x - max_vec[idx[0], 0, (idx[2]%4)//2]).exp2())
+        a_reg = warp.map(a_reg, lambda x, idx: x / norm_vec[idx[0], 0, (idx[2]%4)//2])
+
+        a_smem = warp.store(a_smem, a_reg)
+        b = warp.store(b, a_smem, (0, 0, 0, tile_col), (), axis=2)
+
+      sink = ker.finish()
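For reference, the kernel is a tiled two-pass ("online") softmax: pass one streams BLOCK_SIZE-wide column tiles while keeping a running row maximum (max_vec) and a running sum of exponentials (norm_vec), rescaling the sum by exp2(old_max - new_max) whenever the maximum grows; pass two re-reads the tiles and normalizes with the final statistics. A minimal NumPy sketch of the same math (names mirror the kernel's registers; this is an illustration, not the tk API):

    import numpy as np

    def tiled_softmax(a, block=16):
        rows, n = a.shape
        s = a / np.log(2)                 # pre-scale so exp2 computes the natural exp
        max_vec = np.full(rows, -np.inf)  # running row maximum
        norm_vec = np.zeros(rows)         # running sum of exponentials
        # pass 1: accumulate the row max and the rescaled norm, tile by tile
        for c in range(0, n, block):
            tile = s[:, c:c+block]
            max_last = max_vec
            max_vec = np.maximum(max_vec, tile.max(axis=1))
            norm_vec = norm_vec * np.exp2(max_last - max_vec) \
                     + np.exp2(tile - max_vec[:, None]).sum(axis=1)
        # pass 2: re-read each tile and normalize with the final statistics
        out = np.empty_like(a)
        for c in range(0, n, block):
            out[:, c:c+block] = np.exp2(s[:, c:c+block] - max_vec[:, None]) / norm_vec[:, None]
        return out

    a = np.random.rand(16, 32).astype(np.float32)
    ref = np.exp(a - a.max(axis=1, keepdims=True))
    ref /= ref.sum(axis=1, keepdims=True)
    np.testing.assert_allclose(tiled_softmax(a), ref, atol=1e-5, rtol=1e-5)

This rescaling is the same trick that lets Flash-Attention-style kernels take a softmax over tiles without ever materializing a full row.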
+
+    with Context(DEBUG=0):
+      a = Tensor.rand(1, 1, BLOCK_SIZE, N, dtype="float32")
+      b = Tensor.empty(1, 1, BLOCK_SIZE, N, dtype="float32")
+      Tensor.realize(a, b)
+
+    ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
+    for _ in range(5): ei.run(wait=True)
+    b = b.float()
+    print(b.tolist())
+
+    ref = a.float().softmax(axis=3)
+    print(ref.tolist())
+
+    np.testing.assert_allclose(b.numpy(), ref.numpy(), atol=1e-5, rtol=1e-5)
+
 if __name__ == "__main__":
   unittest.main()
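To run just this test, the standard unittest selector works (assuming a device the tk path supports):

    python3 test/external/external_test_tk.py TestTK.test_softmax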