From cf2bf1518d4bc60c0d1c671853613180c57eb4f3 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Mon, 14 Aug 2023 18:09:04 +0200 Subject: [PATCH 1/8] Better reshape (#1423) * do reshaping without merge_views and reshape masks * added tests * properly do reshaping of zero or negative masks * replace while loop with single expression * remove old condition * add more tests and comments * remove empty file --- test/unit/test_shapetracker.py | 90 +++++++++++++++++++++++++++++++++- tinygrad/shape/shapetracker.py | 69 ++++++++++++++++++-------- 2 files changed, 137 insertions(+), 22 deletions(-) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index e7d0851985..44dcb1aeca 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -122,7 +122,6 @@ class TestRealSimplifies(unittest.TestCase): View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)]) class TestIndexExpressions2d(unittest.TestCase): - def setUp(self): shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5 offsets = [0, 1, 15, 28, 10000] @@ -188,6 +187,7 @@ class TestIndexExpressions2d(unittest.TestCase): st.expand((base_shape[0], base_shape[1], base_shape[1])) self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset) self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) + def test_permute_reshape_1(self): # This tests multiple views for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): st.permute((1, 0)) @@ -417,6 +417,94 @@ class TestMaskedShapeTracker(unittest.TestCase): self.st.pad(((1,1), (1,1))) self.st.assert_same() + def test_reshaping_splitting(self): + self.st = CheckingShapeTracker((5,10,5,10)) + self.st.permute((1, 0, 3, 2)) + self.st.pad(((0,0), (0,5), (0,0), (0,5))) + self.st.reshape((10,2,5,10,2,5)) + assert len(self.st.views) == 1 + self.st.assert_same() + + def test_reshape_combining_1(self): + self.st = CheckingShapeTracker((2,1,10)) + self.st.pad(((2,6), (0,0), (0,0))) + self.st.reshape((100,)) + assert len(self.st.views) == 1 + self.st.assert_same() + + @unittest.skip("Can't make this optimization yet") + def test_reshape_combining_2(self): + self.st = CheckingShapeTracker((1,1,5)) + self.st.pad(((3,6), (0,0), (0,5))) + self.st.reshape((100,)) + assert len(self.st.views) == 1 + self.st.assert_same() + + @unittest.skip("Can't make this optimization yet") + def test_reshape_splitting_combining(self): + self.st = CheckingShapeTracker((1,5,5)) + self.st.pad(((0,4), (0,5), (0,0))) + self.st.reshape((10,25)) + assert len(self.st.views) == 1 + self.st.assert_same() + + def test_reshape_only_1s(self): + self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) + self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0))) + self.st.reshape((5, 6, 3, 5)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() + self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1)) + assert len(self.st.views) == 1 + self.st.assert_same() + + def test_zero_mask_1(self): + self.st = CheckingShapeTracker((1, 3, 2)) + self.st.pad(((0,0), (0,3), (0,0))) + self.st.shrink(((0,1), (3,6), (0,2))) + self.st.reshape((3,2)) + self.st.assert_same() + self.st.reshape((1, 3, 1, 
2, 1)) + self.st.assert_same() + + def test_zero_mask_2(self): + self.st = CheckingShapeTracker((1, 3, 2)) + self.st.pad(((0,2), (0,3), (0,0))) + self.st.shrink(((2,3), (3,6), (0,2))) + self.st.reshape((3,2)) + self.st.assert_same() + self.st.reshape((1, 3, 1, 2, 1)) + self.st.assert_same() + + def test_expanded_reshaped(self): + self.st = CheckingShapeTracker((1, 3, 2, 1)) + self.st.expand((5, 3, 2, 2)) + self.st.pad(((0,0), (0,3), (0,0), (0, 0))) + self.st.reshape((5, 2, 3, 2, 2)) + assert len(self.st.views) == 1 + self.st.assert_same() + + def test_splitting_big(self): + self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) + self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0))) + self.st.reshape((10, 1, 30)) + self.st.permute((2,1,0)) + self.st.reshape((2,3,5,2,5)) + assert len(self.st.views) == 1 + v = self.st.views[-1] + assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) + + def test_combining_big(self): + self.st = CheckingShapeTracker((1,3,1,5,3,1)) + self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0))) + self.st.reshape((1,1,1,105,1,1)) + assert len(self.st.views) == 1 + v = self.st.views[-1] + assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)), v.offset == -30 + class TestShapeTracker(unittest.TestCase): def setUp(self): self.st = CheckingShapeTracker((7,4)) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 7428f453aa..2a06902e7e 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -95,30 +95,58 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]: if None in strides: return None return View(vm1.shape, strides, mst.real_offset(), vm1.mask) +def _reshape_mask(view: View, new_shape:Tuple[int, ...]) -> Tuple[Optional[Tuple[Tuple[int, int], ...]],bool]: + # assumes view can be reshaped to new_shape (if it had no mask), this implies we won't have to worry about strides + if view.mask is None: return view.mask, False + new_mask: List[Tuple[int, int]] = [] + r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int],...], view.mask)), reversed(view.shape), reversed(new_shape) + stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) + while len(new_mask) < len(new_shape): + if mask[1]-mask[0] < 1: # if the mask is never valid, just return all zeros + return ((0,0),)*len(new_shape), False + if old_dim == new_dim*stride: # easy, can just copy the mask + new_mask.append((mask[0]//stride, (mask[1]-1)//stride+1)) + stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) + elif old_dim > new_dim: # splitting the old mask + # we cannot split if the reshape cuts across the mask + if (mask[0]%(new_dim*stride)!=0 or mask[1]%(new_dim*stride)!=0) and mask[0]//(new_dim*stride)!=(mask[1]-1)//(new_dim*stride): + return view.mask, True + new_mask.append((mask[0]%(new_dim*stride)//stride, (mask[1]-1)%(new_dim*stride)//stride+1)) + # the remaining mask still needs to be split, we need to determine the mask for the next dimension + # we maintain the stride + stride *= new_dim + new_dim = next(r_new_shape, 1) + elif old_dim < new_dim*stride: # combining masks + next_mask = next(r_masks, (0,1)) + # if the current dimension is masked, we cannot merge unless the next masks have an index range of 1 + if (mask[0]!=0 or mask[1]!=old_dim) and next_mask[1]-next_mask[0]!=1: + return view.mask, True + # we combine the current mask with the next and go through the loop again 
with the next dimension + mask = (next_mask[0]*old_dim+mask[0], (next_mask[1]-1)*old_dim+mask[1]) + old_dim *= next(r_shape, 1) + for mask in (mask, *r_masks): # if the old shape has leading 1s, need to make sure their mask is (0,1), otherwise the mask is zero'd + if mask != (0,1): return ((0,0),)*len(new_shape), False + return tuple(reversed(new_mask)), False + @functools.lru_cache(maxsize=None) def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]: - shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset - # check if this is adding or removing 1s (only) - # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional) - if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]: - new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1] - new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape]) - new_mask_tuple = None - if mask: - for x,y in zip(shape, mask): - if x == 1 and y != (0, 1): - new_mask_tuple = ((0,0),) * len(new_shape) - break - else: - new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1] - new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape]) - return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False + if view.contiguous or (view.shape == new_shape): return View(new_shape), False + strides, reverse_shape = [], reversed(new_shape) + for d, s in reversed(view.shape_strides): + acc, new_stride = 1, s + while acc < d: + new_dim = next(reverse_shape) + acc *= new_dim + strides.append(new_stride) + new_stride *= new_dim + if acc != d: break + else: + strides += [0,] * (len(new_shape) - len(strides)) + mask, extra = _reshape_mask(view, new_shape) + if not extra: return View(new_shape, tuple(reversed(strides)), view.offset, mask), False - new_view = View(new_shape) - if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset - if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}") - return new_view, True + return View(new_shape), True @functools.lru_cache(maxsize=None) def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]): @@ -147,7 +175,6 @@ class ShapeTracker: # this is the real size (ish) def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0]) - # these are multiview strides, value is None if it's not a simple strided dimension # TODO: this can be shared code between simplify and merge_views def real_offset(self) -> int: real_offset, mask = self.expr_node(Variable('zero', 0, 0)) From 9cb2bda34fb2c5e03d4f9c494b004dddc759ca84 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Mon, 14 Aug 2023 13:04:54 -0400 Subject: [PATCH 2/8] Revert "Better reshape (#1423)" (#1538) --- test/unit/test_shapetracker.py | 90 +--------------------------------- tinygrad/shape/shapetracker.py | 69 ++++++++------------------ 2 files changed, 22 insertions(+), 137 deletions(-) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 44dcb1aeca..e7d0851985 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -122,6 +122,7 @@ class TestRealSimplifies(unittest.TestCase): View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)]) class TestIndexExpressions2d(unittest.TestCase): + def setUp(self): shapes = 
[(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5 offsets = [0, 1, 15, 28, 10000] @@ -187,7 +188,6 @@ class TestIndexExpressions2d(unittest.TestCase): st.expand((base_shape[0], base_shape[1], base_shape[1])) self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset) self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset) - def test_permute_reshape_1(self): # This tests multiple views for st, base_shape, offset in zip(self.sts, self.shapes, self.offset): st.permute((1, 0)) @@ -417,94 +417,6 @@ class TestMaskedShapeTracker(unittest.TestCase): self.st.pad(((1,1), (1,1))) self.st.assert_same() - def test_reshaping_splitting(self): - self.st = CheckingShapeTracker((5,10,5,10)) - self.st.permute((1, 0, 3, 2)) - self.st.pad(((0,0), (0,5), (0,0), (0,5))) - self.st.reshape((10,2,5,10,2,5)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_combining_1(self): - self.st = CheckingShapeTracker((2,1,10)) - self.st.pad(((2,6), (0,0), (0,0))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() - - @unittest.skip("Can't make this optimization yet") - def test_reshape_combining_2(self): - self.st = CheckingShapeTracker((1,1,5)) - self.st.pad(((3,6), (0,0), (0,5))) - self.st.reshape((100,)) - assert len(self.st.views) == 1 - self.st.assert_same() - - @unittest.skip("Can't make this optimization yet") - def test_reshape_splitting_combining(self): - self.st = CheckingShapeTracker((1,5,5)) - self.st.pad(((0,4), (0,5), (0,0))) - self.st.reshape((10,25)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_reshape_only_1s(self): - self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) - self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0))) - self.st.reshape((5, 6, 3, 5)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_zero_mask_1(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,0), (0,3), (0,0))) - self.st.shrink(((0,1), (3,6), (0,2))) - self.st.reshape((3,2)) - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - self.st.assert_same() - - def test_zero_mask_2(self): - self.st = CheckingShapeTracker((1, 3, 2)) - self.st.pad(((0,2), (0,3), (0,0))) - self.st.shrink(((2,3), (3,6), (0,2))) - self.st.reshape((3,2)) - self.st.assert_same() - self.st.reshape((1, 3, 1, 2, 1)) - self.st.assert_same() - - def test_expanded_reshaped(self): - self.st = CheckingShapeTracker((1, 3, 2, 1)) - self.st.expand((5, 3, 2, 2)) - self.st.pad(((0,0), (0,3), (0,0), (0, 0))) - self.st.reshape((5, 2, 3, 2, 2)) - assert len(self.st.views) == 1 - self.st.assert_same() - - def test_splitting_big(self): - self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) - self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0))) - self.st.reshape((10, 1, 30)) - self.st.permute((2,1,0)) - self.st.reshape((2,3,5,2,5)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) - - def test_combining_big(self): - self.st = CheckingShapeTracker((1,3,1,5,3,1)) - 
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0))) - self.st.reshape((1,1,1,105,1,1)) - assert len(self.st.views) == 1 - v = self.st.views[-1] - assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)), v.offset == -30 - class TestShapeTracker(unittest.TestCase): def setUp(self): self.st = CheckingShapeTracker((7,4)) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 2a06902e7e..7428f453aa 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -95,58 +95,30 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]: if None in strides: return None return View(vm1.shape, strides, mst.real_offset(), vm1.mask) -def _reshape_mask(view: View, new_shape:Tuple[int, ...]) -> Tuple[Optional[Tuple[Tuple[int, int], ...]],bool]: - # assumes view can be reshaped to new_shape (if it had no mask), this implies we won't have to worry about strides - if view.mask is None: return view.mask, False - new_mask: List[Tuple[int, int]] = [] - r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int],...], view.mask)), reversed(view.shape), reversed(new_shape) - stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - while len(new_mask) < len(new_shape): - if mask[1]-mask[0] < 1: # if the mask is never valid, just return all zeros - return ((0,0),)*len(new_shape), False - if old_dim == new_dim*stride: # easy, can just copy the mask - new_mask.append((mask[0]//stride, (mask[1]-1)//stride+1)) - stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - elif old_dim > new_dim: # splitting the old mask - # we cannot split if the reshape cuts across the mask - if (mask[0]%(new_dim*stride)!=0 or mask[1]%(new_dim*stride)!=0) and mask[0]//(new_dim*stride)!=(mask[1]-1)//(new_dim*stride): - return view.mask, True - new_mask.append((mask[0]%(new_dim*stride)//stride, (mask[1]-1)%(new_dim*stride)//stride+1)) - # the remaining mask still needs to be split, we need to determine the mask for the next dimension - # we maintain the stride - stride *= new_dim - new_dim = next(r_new_shape, 1) - elif old_dim < new_dim*stride: # combining masks - next_mask = next(r_masks, (0,1)) - # if the current dimension is masked, we cannot merge unless the next masks have an index range of 1 - if (mask[0]!=0 or mask[1]!=old_dim) and next_mask[1]-next_mask[0]!=1: - return view.mask, True - # we combine the current mask with the next and go through the loop again with the next dimension - mask = (next_mask[0]*old_dim+mask[0], (next_mask[1]-1)*old_dim+mask[1]) - old_dim *= next(r_shape, 1) - for mask in (mask, *r_masks): # if the old shape has leading 1s, need to make sure their mask is (0,1), otherwise the mask is zero'd - if mask != (0,1): return ((0,0),)*len(new_shape), False - return tuple(reversed(new_mask)), False - @functools.lru_cache(maxsize=None) def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]: - if view.contiguous or (view.shape == new_shape): return View(new_shape), False - strides, reverse_shape = [], reversed(new_shape) - for d, s in reversed(view.shape_strides): - acc, new_stride = 1, s - while acc < d: - new_dim = next(reverse_shape) - acc *= new_dim - strides.append(new_stride) - new_stride *= new_dim - if acc != d: break - else: - strides += [0,] * (len(new_shape) - len(strides)) - mask, extra = _reshape_mask(view, new_shape) - if not extra: return View(new_shape, tuple(reversed(strides)), view.offset, mask), 
False + shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset + # check if this is adding or removing 1s (only) + # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional) + if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]: + new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1] + new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape]) + new_mask_tuple = None + if mask: + for x,y in zip(shape, mask): + if x == 1 and y != (0, 1): + new_mask_tuple = ((0,0),) * len(new_shape) + break + else: + new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1] + new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape]) + return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False + new_view = View(new_shape) + if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset + if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}") - return View(new_shape), True + return new_view, True @functools.lru_cache(maxsize=None) def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]): @@ -175,6 +147,7 @@ class ShapeTracker: # this is the real size (ish) def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0]) + # these are multiview strides, value is None if it's not a simple strided dimension # TODO: this can be shared code between simplify and merge_views def real_offset(self) -> int: real_offset, mask = self.expr_node(Variable('zero', 0, 0)) From a453d718a172dab88acac29b111d393176cde659 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Tue, 15 Aug 2023 03:50:10 +0200 Subject: [PATCH 3/8] fix file race condition in ops_clang via pid in the filename (#1541) * fix file race condition in ops_clang via pid in the filename as suggested in https://github.com/tinygrad/tinygrad/pull/1458/files#r1292819054 * add explanation why a temp file is required on ops_clang --- tinygrad/runtime/ops_clang.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 587f45f3e3..b6f7fb90a8 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -15,9 +15,11 @@ class ClangProgram: def __init__(self, name:str, prg:str): prg = CLANG_PROGRAM_HEADER + prg # TODO: is there a way to not write this to disk? 
+ # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file + # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}" if not os.path.exists(fn): - _, tmp = tempfile.mkstemp() + tmp = f"{fn}.{os.getpid()}.tmp" subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8')) os.rename(tmp, fn) self.lib = ctypes.CDLL(fn) From a89142e46f579727329b33271b72db68cb4848b1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 14 Aug 2023 18:53:37 -0700 Subject: [PATCH 4/8] ShapeTracker.var_vals (#1540) --- test/test_lazybuffer.py | 24 ------- test/test_symbolic_shapetracker.py | 106 ++++++++++++++++++++--------- test/unit/test_symbolic.py | 10 +-- tinygrad/lazy.py | 11 ++- tinygrad/shape/shapetracker.py | 28 ++++---- tinygrad/shape/symbolic.py | 2 - 6 files changed, 96 insertions(+), 85 deletions(-) diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index cc71a0d0f4..19afc25de5 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -68,29 +68,5 @@ class TestLazyBuffer(unittest.TestCase): assert GlobalCounters.cache[2][0].name.startswith("E_") GlobalCounters.cache = None -class TestVariableBuffer(unittest.TestCase): - def test_get_variable_buffers_no_variable(self): - t = Tensor.rand(2, 3) - assert t.lazydata.get_variable_buffers() == {} - - def test_get_variable_buffers_one_variable(self): - v = Variable("v", 1, 10) - t = Tensor.rand(2, 3).reshape(v, 3) - buffers = t.lazydata.get_variable_buffers() - assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 2 - v = Variable("v", 1, 10) - t = Tensor.rand(2, 3).reshape(2, v) - buffers = t.lazydata.get_variable_buffers() - assert len(buffers) == 1 and buffers[v].realize().realized.toCPU() == 3 - - def test_get_variable_buffers_cat(self): - v1 = Variable("v1", 1, 10) - v2 = Variable("v2", 1, 10) - t1 = Tensor.rand(2, 3).reshape(v1, 3) - t2 = Tensor.rand(6, 3).reshape(v2, 3) - t = t1.cat(t2) - buffers = t.lazydata.get_variable_buffers() - assert len(buffers) == 2 and buffers[v1].realize().realized.toCPU() == 2 and buffers[v2].realize().realized.toCPU() == 6 - if __name__ == "__main__": unittest.main() diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 7a0ee876bb..8a8857b97b 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -40,65 +40,91 @@ class TestSymbolic(unittest.TestCase): class TestSymbolicReshape(unittest.TestCase): def test_reshape_into_symbols_simple(self): - for i in range(1, 5): - vi = Variable("i", 1, 10) - assert Tensor.rand(i, 4).reshape(vi, 4).shape == (vi, 4) - assert vi.val == i - vi = Variable("i", 1, 10) - assert Tensor.rand(i, 6).reshape(vi, 2, 3).shape == (vi, 2, 3) - assert vi.val == i + vi = Variable("i", 1, 5) + for i in range(1, 6): + t = Tensor.rand(i, 4).reshape(vi, 4) + assert t.shape == (vi, 4) + assert t.lazydata.st.var_vals[vi] == i + t = Tensor.rand(i, 6).reshape(vi, 2, 3) + assert t.shape == (vi, 2, 3) + assert t.lazydata.st.var_vals[vi] == i def test_reshape_symbols_reshape_ints(self): - for i in range(1, 5): - vi = Variable("i", 1, 10) - assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i, 4).shape == (i, 4) - assert Tensor.rand(i, 4).reshape(vi, 4).reshape(i*4,).shape == (i*4,) - assert Tensor.rand(i, 6).reshape(vi, 6).reshape(i*2, 3).shape 
== (i*2, 3) - with self.assertRaises(AssertionError): - Tensor.rand(i, 6).reshape(vi, 6).reshape(1, 77).shape + vi = Variable("i", 1, 5) + for i in range(1, 6): + t = Tensor.rand(i, 4).reshape(vi, 4) + assert t.shape == (vi, 4) + assert t.lazydata.st.var_vals == {vi: i} + t = t.reshape(i, 4) + assert t.shape == (i, 4) + assert t.lazydata.st.var_vals == {} def test_reshape_reuse_var_same_value_ok(self): - for i in range(1, 5): - vi = Variable("i", 1, 10) + vi = Variable("i", 1, 5) + for i in range(1, 6): a = Tensor.rand(i, 4).reshape(vi, 4) b = Tensor.rand(i, 3).reshape(vi, 3) - assert vi.val == i + assert a.lazydata.st.var_vals[vi] == i + assert b.lazydata.st.var_vals[vi] == i - def test_reshape_reuse_var_different_value_fail(self): - for i in range(1, 5): - vi = Variable("i", 1, 10) + def test_reshape_reuse_var_different_value_ok(self): + vi = Variable("i", 1, 10) + for i in range(1, 6): a = Tensor.rand(i, 4).reshape(vi, 2) - with self.assertRaises(AssertionError): - b = Tensor.rand(i, 3).reshape(vi, 3) + b = Tensor.rand(i, 3).reshape(vi, 3) + # a and b have different values of vi + assert a.lazydata.st.var_vals[vi] == 2 * i + assert b.lazydata.st.var_vals[vi] == i def test_reshape_into_symbols_bad_shape(self): vi = Variable("i", 1, 10) vj = Variable("j", 1, 10) with self.assertRaises(AssertionError): - t = Tensor.rand(3, 4).reshape(vi, vj) + t = Tensor.rand(3, 4).reshape(vi, vj) # reshape into two variables with self.assertRaises(AssertionError): - t = Tensor.rand(4, 4).reshape(vi, vi) + t = Tensor.rand(4, 4).reshape(vi, vi) # reshape into same variable in 2 dimensions with self.assertRaises(AssertionError): - t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) + t = Tensor.rand(4, 6).reshape(vi, 6).reshape(vi, 4) # conflicted implied variable values + with self.assertRaises(AssertionError): + t = Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape with self.assertRaises(AssertionError): t = Tensor.rand(100, 4).reshape(Variable("too_small", 1, 10), 4) with self.assertRaises(AssertionError): t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4) + with self.assertRaises(AssertionError): + t = Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node + + def test_two_symbol_reshape(self): + vi = Variable("i", 1, 5) + vj = Variable("j", 1, 5) + for i in range(1, 6): + for j in range(1, 6): + t1 = Tensor.rand(i, 5).reshape(vi, 5) + t2 = Tensor.rand(5, j).reshape(5, vj) + t = t1@t2 + assert t.shape == (vi, vj) + t = t.reshape(1, vi*vj) + assert t.shape == (1, vi*vj) + t = t.reshape(vj, vi) + assert t.shape == (vj, vi) class TestSymbolicExpand(unittest.TestCase): def test_expand_into_symbols(self): - vi = Variable("i", 1, 10) + vi = Variable("i", 1, 5) + vj = Variable("j", 1, 5) a = Tensor([[1], [2], [3]]).expand((3, vi)) assert a.shape == (3, vi) - vj = Variable("j", 1, 10) + assert a.lazydata.st.var_vals == {} a = a.reshape(3, vi, 1).expand((3, vi, vj)) assert a.shape == (3, vi, vj) + assert a.lazydata.st.var_vals == {} def test_plus_expands_constant(self): - vi = Variable("i", 1, 10) - a = Tensor.rand(3, 4).reshape(3, vi) - a = a + 1 - assert a.shape == (3, vi) + vi = Variable("i", 1, 5) + for i in range(1, 6): + a = Tensor.rand(3, i).reshape(3, vi) + a = a + 1 + assert a.shape == (3, vi) class TestSymbolicShapeExpr(unittest.TestCase): def test_symbolic_expr_idxs(self): @@ -114,5 +140,23 @@ class TestSymbolicShapeExpr(unittest.TestCase): idx, valid = st.expr_idxs(idx) assert idx.render() == 
"(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)" +class TestShapeTrackerVarVals(unittest.TestCase): + def test_reshape_reshape_updates_var_vals(self): + vi = Variable("i", 1, 5) + vj = Variable("j", 1, 5) + t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj) + assert t.lazydata.st.var_vals == {vi: 4, vj: 3} + + def test_lazy_check_var_vals(self): + vi = Variable("i", 1, 5) + a = Tensor.rand(3, 4).reshape(3, vi) + b = Tensor.rand(5, 6).reshape(vi, 6) + assert a.lazydata.st.var_vals == {vi: 4} + assert b.lazydata.st.var_vals == {vi: 5} + c = a@b + # shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values + assert c.shape == (3, 6) + assert c.lazydata.st.var_vals == {} + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 22c64d2434..d2ec6810c5 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_vars, sym_render +from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_render class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s): @@ -261,14 +261,6 @@ class TestSymbolicVars(unittest.TestCase): assert (a % 3 + b // 5).vars() == [a, b] assert (a + b + c - a).vars() == [b, c] - def test_sym_vars(self): - a = Variable("a", 0, 10) - b = Variable("b", 0, 10) - assert sym_vars(1) == [] - assert sym_vars(a) == [a] - assert sym_vars(a+b) == [a, b] - assert sym_vars(a*3) == [a] - class TestSymbolicMinMax(unittest.TestCase): def test_min_max_known(self): a = Variable("a", 1, 8) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 1fdb63b4a0..99960ad2d1 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -9,7 +9,7 @@ from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, from tinygrad.runtime.ops_cpu import RawNumpyBuffer from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction -from tinygrad.shape.symbolic import Variable, sym_vars +from tinygrad.shape.symbolic import Node from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer @@ -214,7 +214,7 @@ class LazyBuffer: if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype) - def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer: + def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer: if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: return self.op.replace_with_movement_ops([(op, arg)]) ret = create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype) @@ -231,13 +231,13 @@ class LazyBuffer: return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), 
self.dtype) def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: - if prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach. + if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach. heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides. def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:] return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape) - def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer: if self.shape == arg: return self if not self.realized and self.op.op == MovementOps.RESHAPE: self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? @@ -249,7 +249,7 @@ class LazyBuffer: if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)])) return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg) - def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: + def expand(self: LazyBuffer, arg:Tuple[Union[Node,int], ...]) -> LazyBuffer: if self.shape == arg: return self if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg) @@ -293,7 +293,6 @@ class LazyBuffer: def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,) def map_buffers(self, real_srcs: Dict[Any, Any]): return real_srcs.get(self, self) def get_lazyops(self) -> List[Any]: return [] - def get_variable_buffers(self) -> Dict[Variable, LazyBuffer]: return {v:LazyBuffer.loadop(LoadOps.FROM, (1,), dtypes.int32, self.device, src=LazyBuffer.fromCPU(np.array([v.val], dtype=np.int32))) for s in self.shape for v in sym_vars(s)} def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer: y = self for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 7428f453aa..beba5c4f17 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum, auto import functools from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple -from tinygrad.helpers import prod, DEBUG +from tinygrad.helpers import prod, DEBUG, partition from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int # these ops live here @@ -129,17 +129,18 @@ def get_unsafe_resize_offset(strides, arg): return sum([s * x[0] for s, x in zip(strides,arg)]) class ShapeTracker: - __slots__ = "views" - def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None): + __slots__ = 
"views", "var_vals" + def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None): self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [View(shape)]) - def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" + self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {} + def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})" def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views]) @property def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous @property - def shape(self) -> Tuple[int, ...]: return self.views[-1].shape + def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape) @property def key(self) -> Tuple[View, ...]: return tuple(self.views) @@ -231,15 +232,16 @@ class ShapeTracker: return self def reshape(self, new_shape: Tuple[Union[Node,int], ...]): - # reshape into symbolic shape, update the variable value - if all(isinstance(s, int) for s in self.shape) and len(new_vars:=list(s for s in new_shape if isinstance(s, Variable))) > 0: - assert len(new_vars) == 1, "only one variable is supported in a shape" - new_var, new_val = new_vars[0], prod(self.shape) // prod(s for s in new_shape if isinstance(s, int)) - if new_var.val is None: + new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int)) + if new_nodes and all(isinstance(s, int) for s in self.shape): + # reshape from all int shape into shape with a variable, update the variable value + assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape" + new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints) + if new_var not in self.var_vals: assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]" - new_var.val = new_val - else: assert new_var.val == new_val, f"value conflicts, was {new_var.val}, set to {new_val}" - + self.var_vals[new_var] = new_val + else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" + elif not new_nodes: self.var_vals = {} if self.views[-1].shape == new_shape: return self assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}" # only check size for int shapes. 
we don't check symbolic here as long as the reshape itself can be done diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index b5fb04a2d3..5d5e866015 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -9,7 +9,6 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node)) -def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars() class Node: b: Union[Node, int] @@ -141,7 +140,6 @@ class Variable(Node): def __init__(self, expr:Optional[str], nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax - self.val: Optional[int] = None def vars(self): return [self] class NumNode(Node): From 93a36c36598f9d7ac4f1a8a679d91bd5b75f4c04 Mon Sep 17 00:00:00 2001 From: Steven Anderson <34435120+stevenandersonz@users.noreply.github.com> Date: Mon, 14 Aug 2023 22:29:30 -0400 Subject: [PATCH 5/8] Arm (#1421) * testing new memops * better debugging * testing padded conv * branching with load * refactoring a bit * first try * fixing bugs * fixing some * eq * eq2 * do not use x's * working * fixing imm * getting things working * refactor * pow not working * working except one * refactor: one store mem * refactor: global load * refactor: imm * refactor: cleaning * fixing big offsets * refactor with ci * try ci * typo * another typo * ubuntu default * forgot git * do i need git? * missing packages * adding python-dev * with cache? * buildx action * buildx name issue? * maybe now? * python3 * newline warning * maybe now * i actually need this * ci should work now * improved caching * fixing cache * maybe now it will cache * this * testing cache * trying again * load * missing platform * caching gha * testing cache * full testing * typo * now? * why * adding checkout back * bad formatting * fixing convention issues * supporting python * adding CI flag * testing all * better comments * adding debugging * takes 12x longer * does it output progress now? * ignore models for speed * fixing merge * excluding conv_transpose2d * only 2 test cuz is to slow * another approach * let's see * faster duh * my bad * T_T * typo * sup * with output? * comment test * comment test * comment test * :? * no comment * with cache * back to normal * testing that ci works * back to passing * trying again * does it create another entry * does it create another entry? * build local * hey * Revert "excluding conv_transpose2d" This reverts commit cc7348de03033e032f47d69caff174e2f1a7bfea. * does it cache if done before? * does it cache? 
* done * adding test ops * bad formatting * no need for this * working static mem * sum 1d * add ndim * better reg import * fix stack * back to np * working except for softmax * 5 failing * no pogress * remove keystone * remove keystone * testops passing * cleanups * more cleanup * typo * ci * ci2 * cond import * ci3 * ci4 * ci4 * ci5 * ci5 * ci6 * aligment * test all * correct test * err read_unmapped * passing test * ignore for speed * ignore for speed * ci7 * cleanup * remove docker * fixing merge * fixing bugs * add skipload for const ops * comments * First merge to master: Renderer * fix emulation * passing all tests arm64 * cleaning * fix handcoded binary * cleaning * fix errs * fix runtime arg binary * clean git diff * fix and clean * fixing metal test * cleaning * fix metal test * ci ~8 min * fix pylint and clang * cache the files in ops_clang --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- .github/workflows/test.yml | 25 ++ extra/assembly/assembly.py | 290 +++++++++++---------- extra/assembly/assembly_arm64.py | 171 ++++++++++++ setup.py | 1 + test/external/external_test_speed_llama.py | 2 +- test/test_uops.py | 7 +- tinygrad/ops.py | 6 +- tinygrad/runtime/ops_clang.py | 61 ++++- tinygrad/runtime/ops_metal.py | 2 +- 9 files changed, 409 insertions(+), 156 deletions(-) create mode 100644 extra/assembly/assembly_arm64.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8ac1e08db6..ce82453c69 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -267,3 +267,28 @@ jobs: - name: Run pytest (cuda) if: matrix.backend=='cuda' run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors) and not (test_conv2d and test_tensor.py)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models + + testunicorn: + name: ARM64 unicorn Test + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Cache pip + uses: actions/cache@v3 + with: + path: '~/.cache/pip' + key: unicorn + - name: Install cross-assembler + run: | + sudo apt-get update -y && \ + sudo apt-get install -y --no-install-recommends gcc-aarch64-linux-gnu + - name: Install dependencies + run: pip install -e '.[testing,arm]' --extra-index-url https://download.pytorch.org/whl/cpu + - name: Test arm + run: CI=1 ARM64=1 CLANG=1 python -m pytest -n=auto test/ -k 'not (test_nn.py and (test_conv_transpose2d or test_conv2d))' --ignore=test/models --ignore=test/test_speed_v_torch.py --ignore=test/test_net_speed.py --ignore=test/test_specific_conv.py --ignore=test/unit/test_disk_tensor.py \ No newline at end of file diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py index 978119d031..d2b9fedb0c 100644 --- a/extra/assembly/assembly.py +++ b/extra/assembly/assembly.py @@ -1,7 +1,7 @@ from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict -from tinygrad.codegen.linearizer import Linearizer, UOps, Token -from tinygrad.ops import ASTRunner, BinaryOps, UnaryOps -from tinygrad.helpers import DType, dtypes, DEBUG +from tinygrad.codegen.linearizer import Linearizer, UOps, Token, ConstOp, MemOp, UOp +from tinygrad.ops import BinaryOps, UnaryOps +from tinygrad.helpers import DType, dtypes, DEBUG, getenv from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode import functools import math @@ -20,6 +20,7 @@ 
class Register(NamedTuple): if self.dtype == dtypes._float4: return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)] return [] + class AssemblyInstruction(NamedTuple): op: UOps out: Optional[Register] @@ -27,156 +28,161 @@ class AssemblyInstruction(NamedTuple): arg: Any = None # warp size of 32, s registers are shared across the warp, v are 32-wide vectors -class AssemblyCodegen(Linearizer): +class AssemblyLanguage(NamedTuple): supports_load3: bool = False sin_is_sin2pi: bool = False no_div: bool = False + #TODO: these should be global vars + cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int) + tor: Dict[Any, Register] = {} + ins = [] - def specialize(self, asm:List[AssemblyInstruction]) -> Tuple[str, str]: - raise NotImplementedError("must be implemented") + def newreg(self, tok, dtype=dtypes.float32, scalar=False): + if isinstance(tok, Token): dtype = tok.dtype # this + self.tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar) + if dtype == dtypes._float4: + for off in range(4): + self.tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off) + self.cnts[(dtype, scalar)] += 1 + return ret - # s registers are the addresses and non local indexes - def codegen(self): - self.process() - self.hand_coded_optimizations() - self.limit_global_dims(3) # all GPU asms have 3 (for now) - self.linearize() + def render_numnode(self, b): + key = ("num", b) + if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b)) + return self.tor[key] - cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int) - tor: Dict[Any, Register] = {} - def newreg(tok, dtype=dtypes.float32, scalar=False): - nonlocal cnts, tor - if isinstance(tok, Token): dtype = tok.dtype # this - tor[tok] = ret = Register(f"%{type_to_letter((dtype, scalar))}{cnts[(dtype, scalar)]}", dtype, scalar) - if dtype == dtypes._float4: - for off in range(4): - tor[Token(tok.name, tok.dtype, off)] = Register(ret.nm, dtypes.float, ret.scalar, off) - cnts[(dtype, scalar)] += 1 - return ret + def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register: + key = (op, a, b) + if key not in self.tor: + #if not isinstance(b, Register): b = render_numnode(b) + self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op)) + return self.tor[key] - def render_numnode(b): - key = ("num", b) - if key not in tor: ins.append(AssemblyInstruction(UOps.CONST, newreg(key, scalar=True, dtype=dtypes.int32), [], b)) - return tor[key] + def render_cast(self, a:Register, new_dtype:DType) -> Register: + if a.dtype == new_dtype: return a + key = (a, new_dtype) + if key not in self.tor: + self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])) + return self.tor[key] - def render_alu(op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register: - key = (op, a, b) - if key not in tor: - #if not isinstance(b, Register): b = render_numnode(b) - ins.append(AssemblyInstruction(UOps.ALU, newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op)) - return tor[key] + render_ops = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), + MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), 
self.b), + DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b), + ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b), + LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool), + SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), + AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } - def render_cast(a:Register, new_dtype:DType) -> Register: - if a.dtype == new_dtype: return a - key = (a, new_dtype) - if key not in tor: - ins.append(AssemblyInstruction(UOps.CAST, newreg(key, dtype=new_dtype), [a])) - return tor[key] + def addr_w_offset(self, args): + assert isinstance(args, MemOp) + idx = args.idx*args.memory_dtype.itemsize + off = 0 # TODO: should this be None? + if isinstance(idx, SumNode): + nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] + if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU? + idx -= nums[0] + off = nums[0] + reg = idx.render(self.render_ops, self) + if self.supports_load3: + if reg.scalar: + new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype) + self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)) + reg = new_reg + return self.tor[args.name], reg, off + reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64) + return reg, None, off - render_ops = { Variable: lambda self, ops, ctx: tor[self], NumNode: lambda self, ops, ctx: render_numnode(self.b), - MulNode: lambda self, ops, ctx: render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b), - DivNode: lambda self, ops, ctx: render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b), - ModNode: lambda self, ops, ctx: render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b), - LtNode: lambda self, ops, ctx: render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool), - SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)), - AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } +def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]): + #TODO: Do not use clear() + lang.ins.clear() + lang.tor.clear() + buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL} + buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} + global_size, local_size = [], [] + skipload_branch = 0 + lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype] + for uop,newvar,vin,args in uops: + if uop == UOps.DEFINE_LOCAL: + lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args)) + lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP)) + elif uop == UOps.LOOP: + if args[1] == "global": + for i,var in enumerate(args[0]): + global_size.append(var.max+1) + lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}")) + elif args[1] == "local": + for i,var 
in enumerate(args[0]): + local_size.append(var.max+1) + lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}")) + else: + for var in args[0]: + if not isinstance(var, NumNode): # TODO: why is this coming through? + lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) #FIXME: what should valid be here? + lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr)) + elif uop == UOps.ENDLOOP: + if args[1] not in ["global", "local", "global+local"]: + for var in reversed(args[0]): + if not isinstance(var, NumNode): # TODO: why is this coming through? + lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD)) + pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool) + lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True))) + elif args[1] == "global+local": + for i, var in enumerate(reversed(args[0])): + lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}"))) - def addr_w_offset(args): - idx = args.idx*self.bufs[args.i].dtype.itemsize - off = 0 # TODO: should this be None? - if isinstance(idx, SumNode): - nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] - if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU? - idx -= nums[0] - off = nums[0] - reg = idx.render(render_ops) - if self.supports_load3: - if reg.scalar: - new_reg = newreg((reg.nm, 'vec'), dtype=reg.dtype) - ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)) - reg = new_reg - return tor[f"buf{args.i}"], reg, off - reg = render_alu(BinaryOps.ADD, render_cast(reg, dtypes.uint64), tor[f"buf{args.i}"], dtype=dtypes.uint64) - return reg, None, off - - ins = [] - ins += [AssemblyInstruction(UOps.SPECIAL, newreg(f"buf{i}", dtype=dtypes.uint64, scalar=True), [], f"buf{i}") for i in range(len(self.bufs))] - global_size, local_size = [], [] - skipload_branch = 0 - for uop,newvar,vin,args in self.uops: - if uop == UOps.CONST and newvar is not None: - ins.append(AssemblyInstruction(UOps.CONST, newreg(newvar, dtype=newvar.dtype), [], args)) - elif uop == UOps.DEFINE_LOCAL: - ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args)) - ins.append(AssemblyInstruction(UOps.ALU, newreg("buf-1", dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP)) - elif uop == UOps.LOOP: - if args[1] == "global": - for i,var in enumerate(args[0]): - global_size.append(var.max+1) - ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}")) - elif args[1] == "local": - for i,var in enumerate(args[0]): - local_size.append(var.max+1) - ins.append(AssemblyInstruction(UOps.SPECIAL, newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}")) - else: - for var in args[0]: - if not isinstance(var, NumNode): # TODO: why is this coming through? - ins.append(AssemblyInstruction(UOps.CONST, newreg(var, dtype=dtypes.int32, scalar=True), [], 0)) - ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr)) - elif uop == UOps.ENDLOOP: - if args[1] not in ["global", "local", "global+local"]: - for var in reversed(args[0]): - if not isinstance(var, NumNode): # TODO: why is this coming through? 
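The loop back-edge being restructured in this hunk (bump the counter, compare against max+1, conditionally branch back to the loop label) is easier to see in isolation. A minimal self-contained sketch follows, with a simplified Ins tuple standing in for AssemblyInstruction; all names here are illustrative, not tinygrad's:

    from typing import Any, List, NamedTuple, Optional

    class Ins(NamedTuple):
        op: str
        out: Optional[str]
        vin: List[Any]
        arg: Any = None

    def lower_loop(var: str, max_val: int, body: List[Ins]) -> List[Ins]:
        return [
            Ins("LOAD", var, [], 0),                            # counter starts at 0
            Ins("LABEL", None, [], f"$loop_{var}"),             # loop head label
            *body,
            Ins("ALU", var, [var, 1], "ADD"),                   # var += 1
            Ins("ALU", "pred", [var, max_val + 1], "CMPLT"),    # pred = var < max+1
            Ins("COND_BRANCH", None, ["pred"], (f"$loop_{var}", True)),  # taken while pred holds
        ]

    for ins in lower_loop("ridx", 3, [Ins("ALU", "acc", ["acc", "ridx"], "ADD")]):
        print(ins)

Running it prints the label/increment/compare/branch sequence that the ENDLOOP case emits for a non-global loop.
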
- ins.append(AssemblyInstruction(UOps.ALU, tor[var], [tor[var], 1], BinaryOps.ADD)) - pred = render_alu(BinaryOps.CMPLT, tor[var], var.max+1, dtypes.bool) - ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True))) - elif uop == UOps.CAST and newvar is not None: - # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies - out = newreg(newvar) - for i,sr in enumerate(out.subregs()): - ins.append(AssemblyInstruction(UOps.ALU, sr, [tor[vin[i]]], UnaryOps.NOOP)) - elif uop == UOps.ALU and newvar is not None: - out = newreg(newvar) if newvar not in tor else tor[newvar] - # this is the only thing that can violate SSA - if args in [BinaryOps.CMPLT]: - pred_reg = newreg((newvar, 'pred'), dtype=dtypes.bool) - ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [tor[x] for x in vin], args)) - ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args)) - elif args == BinaryOps.DIV and self.no_div: - tmp = newreg((newvar, "rcp")) - ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[1]]], UnaryOps.RECIP)) - ins.append(AssemblyInstruction(UOps.ALU, out, [tor[vin[0]], tmp], BinaryOps.MUL)) - elif args == UnaryOps.SIN and self.sin_is_sin2pi: - tmp = newreg((newvar, "2pi")) - ins.append(AssemblyInstruction(UOps.ALU, tmp, [tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL)) - ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args)) - else: - ins.append(AssemblyInstruction(UOps.ALU, out, [tor[x] for x in vin], args)) - elif uop == UOps.LOAD and newvar is not None: - idx, treg, off = addr_w_offset(args) - reg = newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype))) - if args.valid.min == 0: - ins.append(AssemblyInstruction(UOps.CONST, reg, [], 0)) - if args.valid.max == 1: - pred = args.valid.render(render_ops) - ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) - if args.valid.max == 1: - # NOTE: you can't compute the index in here, because it assumes it's all available later - ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared'))) + elif uop == UOps.CAST and newvar is not None: + # TODO: we should reconsider outputting CAST in the linearizer. 
these are needless copies + out = lang.newreg(newvar) + for i,sr in enumerate(out.subregs()): + lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)) + elif uop == UOps.ALU and newvar is not None: + out = lang.newreg(newvar) if newvar not in lang.tor else lang.tor[newvar] + # this is the only thing that can violate SSA + if args in [BinaryOps.CMPLT]: + pred_reg = lang.newreg((newvar, 'pred'), dtype=dtypes.bool) + lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args)) + lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args)) + elif args == BinaryOps.DIV and lang.no_div: + tmp = lang.newreg((newvar, "rcp")) + lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP)) + lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL)) + elif args == UnaryOps.SIN and lang.sin_is_sin2pi: + tmp = lang.newreg((newvar, "2pi")) + lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL)) + lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args)) + else: + lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)) + elif uop == UOps.LOAD and newvar is not None: + if isinstance(args, ConstOp): if args.valid.min == 0 and args.valid.max == 1: - ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) + lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.invalid_value)) + pred = args.valid.render(lang.render_ops, lang) + lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) + lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value)) + lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) skipload_branch += 1 - elif uop == UOps.STORE: - idx, treg, off = addr_w_offset(args) - ins.append(AssemblyInstruction(UOps.STORE, None, [idx, tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if args.i != -1 else 'shared'))) + else: + lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(newvar, dtype=newvar.dtype), [], args.value if args.valid.min == 1 else args.invalid_value)) + else: + idx, treg, off = lang.addr_w_offset(args) + reg = lang.newreg(newvar, dtype=newvar.dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) # and not dtypes.is_float(newvar.dtype))) + if args.valid.min == 0: + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0)) + if args.valid.max == 1: + pred = args.valid.render(lang.render_ops, lang) + lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) + if args.valid.max == 1: + # NOTE: you can't compute the index in here, because it assumes it's all available later + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype[args.name] != dtypes.float else None))) + if args.valid.min == 0 and args.valid.max == 1: + lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) + skipload_branch += 1 + elif uop == UOps.STORE: + idx, treg, off = lang.addr_w_offset(args) + lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if 
buf_index[args.name] != -1 else 'shared', args.memory_dtype if buf_to_dtype['data0'] != dtypes.float else None))) + # define registers + lang.ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in lang.cnts.items()] + lang.ins - # define registers - ins = [AssemblyInstruction(UOps.DEFINE_REGISTER, None, [], (dtype, type_to_letter(dtype), c)) for dtype,c in cnts.items()] + ins - - if DEBUG >= 4: - for tins in ins: print(tins) - name, asm = self.specialize(ins) - - return ASTRunner(name, asm, - global_size[::-1], local_size[::-1], - op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name, runtime_args={"binary": True}) + if DEBUG >= 4: + for tins in lang.ins: print(tins) + return global_size, local_size \ No newline at end of file diff --git a/extra/assembly/assembly_arm64.py b/extra/assembly/assembly_arm64.py new file mode 100644 index 0000000000..ccdeb9abf6 --- /dev/null +++ b/extra/assembly/assembly_arm64.py @@ -0,0 +1,171 @@ +import struct +from platform import system +from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register +from typing import Tuple, Set, Dict, List +from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps +from tinygrad.codegen.linearizer import UOps, ConstOp, UOp +from tinygrad.helpers import dtypes, CI + +def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1]) +def compute_offsets(total): + quotient, remainder = divmod(total, 4096) + return [4096]*quotient + [remainder] if remainder else [4096]*quotient + +#NOTE: Darwin needs names to start with a "_" +def get_name(name): return ('_' if system() == 'Darwin' else '') + name + +class ARM64Language(AssemblyLanguage): pass + +def specialize_to_arm64(fn_nm, asm): + var_size = 16 + prev_uop = None + ins = [] + x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)] + s_regs = ['s' + str(i) for i in reversed(range(3,30))] + type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'} + alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max", + BinaryOps.MOD: "", BinaryOps.CMPLT: "subs", + UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"), + TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"} + + def mov_imm(value, reg): + # Manually move value into reg if value can't fit + if value.__class__ is not float and abs(value) > abs(65535): + ins.append(f"movz w15, #{value & 0xffff}") + ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16") + ins.append(f"sxtw {reg}, w15") + elif reg[0] == 's': + ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}") + ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16") + ins.append(f"str x15, [sp, 16]") + ins.append(f"ldr {reg}, [sp, 16]") + else: + ins.append(f"mov {reg}, #{value}") + + # Get variables intervals + live_range:Dict[str, str] = {} + for i, (uop, out, vin, arg) in enumerate(asm): + for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]): + live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i] + + mem_vars:Dict[str, str] = {} + rtor:Dict[str, str] = {} + def allocate_regs(vars): + nonlocal var_size + for v in [v for v in vars if v is not 
None and v.__class__ is not int and v.nm not in rtor]:
+      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
+      #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
+      if len(available_regs) == 0:
+        # ARM needs the stack 16-byte aligned
+        var_size += 16
+        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x11')
+        mem_vars[v.nm] = var_size
+      rtor[v.nm] = available_regs.pop()
+
+  temp_floats = ['s0', 's1', 's2']
+  temp_ints = ['x11', 'x12', 'x13']
+  for i, (uop, out, vin, arg) in enumerate(asm):
+    # Clear regs out of interval
+    for var, reg in list(rtor.items()):
+      available_regs = s_regs if reg[0] == 's' else x_regs
+      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
+        available_regs.append(rtor.pop(var))
+    # Assign registers to the variables using live ranges.
+    allocate_regs([out] + vin)
+    # Assign temp regs to vin and load them before direct use
+    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
+      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
+      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
+      ins.append(f"mov x15, {mem_vars[v.nm]}")
+      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
+
+    if uop == UOps.SPECIAL:
+      if arg.startswith('data'):
+        # data 8 to n into the stack
+        if int(arg[4:]) >= 8:
+          ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
+          ins.append(f"mov {rtor[out.nm]}, x15")
+      else:
+        ins.append(f"mov {rtor[out.nm]}, #0")
+        ins.append(f"loop_{arg}:")
+    elif uop == UOps.CAST:
+      if arg == BinaryOps.CMPLT:
+        mov_imm(0.0, 's0')
+        mov_imm(1.0, 's1')
+        ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
+      else:
+        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
+    elif uop == UOps.ALU:
+      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
+      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
+        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
+      elif arg == TernaryOps.WHERE:
+        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0")
+        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
+      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
+        #NOTE: Not a real instruction, used to emulate an ext call in unicorn
+        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
+        else:
+          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
+          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
+          # Save the registers before they are cleared by func call
+          for i,k in enumerate(save_regs,1):
+            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
+          ins.append("stp x29, x30, [sp, #0]!")
+          ins.append("mov x29, sp")
+          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
+          ins.append(alu[arg])
+          ins.append(f"fmov {rtor[out.nm]}, s0")
+          ins.append("mov sp, x29")
+          ins.append("ldp x29, x30, [sp], #0")
+          for i,k in enumerate(save_regs,1):
+            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
+          ins.append(f"add sp, sp, #{len(save_regs)*16}")
+      elif arg == BinaryOps.CMPLT:
+        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
+      elif arg == BinaryOps.MOD:
+        ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15")
+        ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}")
+      else:
+        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
+    elif uop == UOps.LOAD:
+      if arg.__class__ in (int, float):
+        mov_imm(arg, rtor[out.nm])
+      else:
+        #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
+        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
+        mov_imm(arg[0], "x15")
+        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
+        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
+        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] == dtypes.half else 'scvtf'} {rtor[out.nm]}, {reg_in}")
+    elif uop == UOps.STORE:
+      shifts = {dtypes.int64: "#3", dtypes.half: "#1", dtypes.int8:"#2", dtypes.uint8: "#2", dtypes.bool: "#2"}
+      #NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
+      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
+      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] != dtypes.half else '' } {reg_out}, {rtor[vin[1].nm]}")
+      ins.append(f"mov x15, #{arg[0]}")
+      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl {shifts[arg[2]] if arg[2] is not None and arg[2] in shifts else '#0'}]")
+    elif uop == UOps.COND_BRANCH:
+      #TODO: this is a hack; it shouldn't assume there is always a cmp before a cond branch
+      if prev_uop == UOps.LOAD:
+        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
+      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
+    elif uop == UOps.LABEL:
+      ins.append(f"{arg[1:]}:")
+    elif uop == UOps.ENDLOOP:
+      mov_imm(arg[0], "x15")
+      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
+      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
+      ins.append(f"b.lt loop_{arg[1]}")
+
+    prev_uop=uop
+    # store regs into memory if needed
+    if out is not None and out.nm in mem_vars:
+      ins.append(f"mov x15, {mem_vars[out.nm]}")
+      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
+  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x19, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
+
+def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
+  lang = ARM64Language()
+  global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
+  return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 826b63fe7c..33f1e65741 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@ setup(name='tinygrad',
     extras_require={
       'llvm': ["llvmlite"],
       'cuda': ["pycuda"],
+      'arm': ["unicorn"],
       'triton': ["triton>=2.0.0.dev20221202"],
       'webgpu': ["wgpu"],
       'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py
index 233bfdad53..c009119e01 100644
--- a/test/external/external_test_speed_llama.py
+++ b/test/external/external_test_speed_llama.py
@@ -11,7 +11,7 @@ from tinygrad.helpers import dtypes, prod
 from tinygrad.runtime.lib import RawBuffer
 
 class FakeProgram:
-  def __init__(self, name:str, prg:str): pass
+  def __init__(self, name:str, prg:str, binary:bool): pass
   def __call__(self, global_size, local_size, *bufs, wait=False): pass
 
 class RawFakeBuffer(RawBuffer):
diff --git a/test/test_uops.py b/test/test_uops.py
index 8771d88bd0..766ad2e31f 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -1,14 +1,15 @@
 import unittest, math
 import numpy as np
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import dtypes, getenv
 from tinygrad.tensor import Device
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp
 from tinygrad.shape.symbolic import Variable
 
 def _uops_to_prg(uops):
-  src, global_size, local_size = Device[Device.DEFAULT].renderer("test", uops)
-  return ASTRunner("test", src, global_size, local_size).build(Device[Device.DEFAULT].runtime)
+  ret = Device[Device.DEFAULT].renderer("test", uops)
+  src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
+  return ASTRunner("test", src, global_size, local_size, runtime_args={"binary": binary}).build(Device[Device.DEFAULT].runtime)
 
 def _test_single_value(tc, tt, vals, op):
   uops = [
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index c49a5e3d9d..aba53b2ce6 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -156,10 +156,12 @@ class Compiled:
   def to_program(self, k):
     k.linearize()
-    src, global_size, local_size = self.renderer(k.function_name, k.uops)
+    ret = self.renderer(k.function_name, k.uops)
+    src, global_size, local_size, binary = ret if len(ret) == 4 else ret + (False,)
+    #TODO: I need to find a better way to select ARM64
     return ASTRunner(k.function_name, src, global_size, local_size,
                      op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name).build(self.runtime)
+                     display_name=k.display_name, runtime_args={"binary": binary}).build(self.runtime)
 
   def exec_ast(self, ast:LazyOp, output, **kwargs):
     # all movementops do nothing in a Compiled buffer!
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index b6f7fb90a8..447edd5016 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -1,8 +1,15 @@
 import os, time, ctypes, hashlib, subprocess, platform, tempfile, functools
+from functools import partial, reduce
 from tinygrad.ops import Compiled
+from tinygrad.helpers import fromimport, getenv, DEBUG, CI
 from tinygrad.runtime.lib import RawMallocBuffer
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
+import struct
+import numpy as np
+
+ARM64 = getenv('ARM64', False)
+if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const
 
 args = {
   'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -11,24 +18,64 @@ args = {
 }[platform.system()]
 
 CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
+ADDRESS = 0x10000
+
+# Unicorn doesn't support external calls
+def align(addr): return (addr+4095) & ~(4095)
+mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
+def emulate_ext_calls(fn, uc, address, size, user_data):
+  s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
+  uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
+
 class ClangProgram:
-  def __init__(self, name:str, prg:str):
-    prg = CLANG_PROGRAM_HEADER + prg
+  def __init__(self, name:str, prg:str, binary:bool=False):
     # TODO: is there a way to not write this to disk?
     # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
     # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
     fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
     if not os.path.exists(fn):
       tmp = f"{fn}.{os.getpid()}.tmp"
-      subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
-      os.rename(tmp, fn)
+      if not binary:
+        prg = CLANG_PROGRAM_HEADER + prg
+        subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
+        os.rename(tmp, fn)
+      else:
+        if DEBUG >= 5: print(prg)
+        if CI and ARM64:
+          prg = prg.split('\n') # type: ignore
+          self.varsize = align(int(prg[0].split(" ")[1]))
+          self.ext_calls = {(i*4+ADDRESS):ins.split(" ")[1:] for i, ins in enumerate(filter(lambda ins: ins[:4] != 'loop', prg[6:-3])) if ins[:2] == 'bl'}
+          prg = "\n".join(['nop' if ins[:2] == 'bl' else ins for ins in prg[6:-3]] + ['\n'])
+          subprocess.check_output(args=('aarch64-linux-gnu-as -o '+tmp).split(), input=prg.encode('utf-8'))
+          subprocess.check_output(args=('aarch64-linux-gnu-objcopy -O binary --only-section=.text '+tmp+' '+fn+'.bin').split())
+          self.prg = open(fn + '.bin', 'rb').read()
+          return
+        subprocess.check_output(args=('as -o' + tmp).split(), input=prg.encode('utf-8'))
+        subprocess.check_output(args=('clang -lm -shared '+tmp+' -o'+fn).split())
     self.lib = ctypes.CDLL(fn)
     self.fxn = self.lib[name]
-
   def __call__(self, global_size, local_size, *args, wait=False):
     if wait: st = time.monotonic()
-    self.fxn(*[x._buf for x in args])
+    if CI and ARM64:
+      mu = Uc(UC_ARCH_ARM64, UC_MODE_ARM)
+      total_mem = align(reduce(lambda total, arg: total + arg.size * arg.dtype.itemsize, args, len(self.prg)+self.varsize))
+      mu.mem_map(ADDRESS, total_mem)
+      for k, fn in self.ext_calls.items(): mu.hook_add(UC_HOOK_CODE, partial(emulate_ext_calls, fn), begin=k, end=k)
+      mu.mem_write(ADDRESS, self.prg + b''.join(bytes(arg._buf) for arg in args))
+      addr = ADDRESS + len(self.prg)
+      for i, arg in enumerate(args):
+        if i<=7:
+          mu.reg_write(getattr(arm64_const, f'UC_ARM64_REG_X{i}'), addr)
+        else:
+          # NOTE: In ARM, args beyond the first 8 are placed on the stack; this also accounts for the stack red zone.
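A brief aside on the placement rule the NOTE above describes: it is the standard AArch64 procedure-call convention, where the first eight integer/pointer arguments travel in registers x0-x7 and any further arguments are written to the stack. A minimal sketch of that rule (the helper name and layout are illustrative, not from the patch):

def place_args(addrs, stack_base):
  # first eight pointers ride in x0-x7; the rest land in consecutive 8-byte stack slots
  regs = {f"x{i}": a for i, a in enumerate(addrs[:8])}
  stack = {stack_base + 8*i: a for i, a in enumerate(addrs[8:])}
  return regs, stack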
+ mu.mem_write(ADDRESS + total_mem - (len(args[8:])+2)*8 + 8*(i-8), addr.to_bytes(8, 'little')) + addr += arg.size * arg.dtype.itemsize + mu.reg_write(arm64_const.UC_ARM64_REG_SP, ADDRESS + total_mem - (len(args[8:])+2)*8) + mu.emu_start(ADDRESS, ADDRESS + len(self.prg)) + args[0]._buf = mu.mem_read(mu.reg_read(arm64_const.UC_ARM64_REG_X0), args[0].size * args[0].dtype.itemsize) + else: + self.fxn(*[x._buf for x in args]) if wait: return time.monotonic()-st -renderer = functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict")) +renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict")) ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 918878c0ee..4c5e564d17 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -37,7 +37,7 @@ def unwrap(x): return ret class MetalProgram: - def __init__(self, name:str, prg:str): + def __init__(self, name:str, prg:str, binary:bool=False): if METAL_XCODE: air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8')) # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode From 666ac61070160479763e206a0b428f78b93f1854 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Tue, 15 Aug 2023 01:39:57 -0400 Subject: [PATCH 6/8] support for p2p buffer transfers (#1523) * feat: RawBufferTransfer * feat: gate behind P2P * feat: gate properly * feat: raise error when not implemented --- tinygrad/lazy.py | 7 +++++-- tinygrad/runtime/lib.py | 9 +++++++++ tinygrad/runtime/ops_gpu.py | 8 ++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 99960ad2d1..ceb4d4e12d 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -11,7 +11,7 @@ from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, View, get_contraction from tinygrad.shape.symbolic import Node from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, TernaryOps, ReduceOps, LoadOps, OpType, LazyOp -from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer +from tinygrad.runtime.lib import RawBufferMapped, RawConst, RawBuffer, RawBufferTransfer # lazy can recurse a lot sys.setrecursionlimit(10000) @@ -19,6 +19,7 @@ sys.setrecursionlimit(10000) OPT = getenv("OPT", 2) LAZY = getenv("LAZY", 1) LAZYCACHE = getenv("LAZYCACHE", 1) +P2P = getenv("P2P", 0) # TODO: movement ops that only change shape are really nops. 
treat them as such REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1 @@ -380,6 +381,8 @@ def _realize_from(buffer: LazyBuffer) -> None: if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped): buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args()) rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer()) + elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1: + buffer.realized = cast(RawBufferTransfer, Device[buffer.device].buffer).transfer(rawbuf.realized, buffer.shape, buffer.dtype, **buffer._device_extra_args()) else: buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args()) @@ -412,4 +415,4 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = { MovementOps.PERMUTE: LazyBuffer.permute, MovementOps.PAD: LazyBuffer.pad, MovementOps.STRIDE: LazyBuffer.stride, -} \ No newline at end of file +} diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index 6b32ffa937..f930d5b5a5 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -50,6 +50,15 @@ class RawBufferCopyInOut(RawBufferCopyIn): self._copyout(x) return x +class RawBufferTransfer(RawBuffer): + def _transfer(self, x) -> None: raise NotImplementedError("must be implemented") + + @classmethod + def transfer(cls, x, shape, dtype, **kwargs): + ret = cls(prod(shape), dtype, **kwargs) + ret._transfer(x) + return ret + class RawConst(RawBuffer): # pylint: disable=abstract-method def __repr__(self): return f"const<{self._buf}, {self.dtype}>" @property diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index ea18ef6788..90eb237e14 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -5,7 +5,7 @@ import pyopencl as cl # type: ignore from typing import Optional, List from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferCopyInOut +from tinygrad.runtime.lib import RawBufferCopyInOut, RawBufferTransfer from tinygrad.codegen.linearizer import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -29,7 +29,7 @@ class _CL: CL = _CL() CL.post_init() if not getenv("DELAYED_RUNTIME_INIT", False) else None -class CLBuffer(RawBufferCopyInOut): +class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): def __init__(self, size, dtype, device='0'): if isinstance(dtype, ImageDType): fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize]) @@ -49,6 +49,10 @@ class CLBuffer(RawBufferCopyInOut): buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data) mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False) with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else [])) + def _transfer(self, x): + if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name: + cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait() + else: raise NotImplementedError("p2p 
transfer between devices not implemented on non-amd") class CLProgram: def __init__(self, name:str, prg:str, binary=False, argdtypes=None, options=None): From f8109b830c785be13160d942997e3dc567282fc6 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 14 Aug 2023 22:47:45 -0700 Subject: [PATCH 7/8] promote assembly to the main codebase (#1544) * promote assembly to the main codebase * not namedtuple --- .../assembly => tinygrad/codegen}/assembly.py | 16 ++++----- .../codegen}/assembly_arm64.py | 33 +++++++++---------- tinygrad/runtime/ops_clang.py | 10 +++--- tinygrad/runtime/ops_webgpu.py | 2 +- 4 files changed, 30 insertions(+), 31 deletions(-) rename {extra/assembly => tinygrad/codegen}/assembly.py (96%) rename {extra/assembly => tinygrad/codegen}/assembly_arm64.py (92%) diff --git a/extra/assembly/assembly.py b/tinygrad/codegen/assembly.py similarity index 96% rename from extra/assembly/assembly.py rename to tinygrad/codegen/assembly.py index d2b9fedb0c..9ab45070bc 100644 --- a/extra/assembly/assembly.py +++ b/tinygrad/codegen/assembly.py @@ -1,7 +1,7 @@ -from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict -from tinygrad.codegen.linearizer import Linearizer, UOps, Token, ConstOp, MemOp, UOp +from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast +from tinygrad.codegen.linearizer import UOps, Token, ConstOp, MemOp, UOp from tinygrad.ops import BinaryOps, UnaryOps -from tinygrad.helpers import DType, dtypes, DEBUG, getenv +from tinygrad.helpers import DType, dtypes, DEBUG from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode import functools import math @@ -28,14 +28,14 @@ class AssemblyInstruction(NamedTuple): arg: Any = None # warp size of 32, s registers are shared across the warp, v are 32-wide vectors -class AssemblyLanguage(NamedTuple): +class AssemblyLanguage: supports_load3: bool = False sin_is_sin2pi: bool = False no_div: bool = False #TODO: these should be global vars cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int) tor: Dict[Any, Register] = {} - ins = [] + ins: List[AssemblyInstruction] = [] def newreg(self, tok, dtype=dtypes.float32, scalar=False): if isinstance(tok, Token): dtype = tok.dtype # this @@ -65,7 +65,7 @@ class AssemblyLanguage(NamedTuple): self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])) return self.tor[key] - render_ops = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), + render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b), MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b), DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b), ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b), @@ -81,7 +81,7 @@ class AssemblyLanguage(NamedTuple): nums = [n.b for n in idx.nodes if isinstance(n, NumNode)] if len(nums) > 0 and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU? 
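The check above gates an optimization: if the index expression carries a small constant term, that constant can be peeled off and carried as the instruction's immediate offset instead of being computed into a register. A rough standalone sketch of the idea, with hypothetical names (not tinygrad's API):

def peel_const_offset(terms, limit=4096):
  # terms: (symbol, coeff) pairs; a None symbol marks the constant term
  # the real code also checks that the remaining index stays non-negative
  const = sum(c for s, c in terms if s is None)
  if 0 < const < limit:
    return [(s, c) for s, c in terms if s is not None], const  # base terms, immediate
  return terms, 0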
      idx -= nums[0]
-        off = nums[0]
+        off = cast(int, nums[0])
       reg = idx.render(self.render_ops, self)
       if self.supports_load3:
         if reg.scalar:
@@ -93,7 +93,7 @@ class AssemblyLanguage:
     return reg, None, off
 
 def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
-  #TODO: Do not use clear() 
+  #TODO: Do not use clear()
   lang.ins.clear()
   lang.tor.clear()
   buf_to_dtype = {args[0]:args[1] for uop,_,_,args in uops if uop == UOps.DEFINE_GLOBAL}
diff --git a/extra/assembly/assembly_arm64.py b/tinygrad/codegen/assembly_arm64.py
similarity index 92%
rename from extra/assembly/assembly_arm64.py
rename to tinygrad/codegen/assembly_arm64.py
index ccdeb9abf6..8f4e824215 100644
--- a/extra/assembly/assembly_arm64.py
+++ b/tinygrad/codegen/assembly_arm64.py
@@ -1,25 +1,25 @@
 import struct
 from platform import system
-from extra.assembly.assembly import uops_to_asmstyle, AssemblyLanguage, Register
-from typing import Tuple, Set, Dict, List
+from typing import Tuple, Dict, List, Optional
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.linearizer import UOps, ConstOp, UOp
+from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.helpers import dtypes, CI
+from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
 
 def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
 def compute_offsets(total):
   quotient, remainder = divmod(total, 4096)
   return [4096]*quotient + [remainder] if remainder else [4096]*quotient
 
-#NOTE: Darwin needs names to start with a "_" 
+#NOTE: Darwin needs names to start with a "_"
 def get_name(name): return ('_' if system() == 'Darwin' else '') + name
 
 class ARM64Language(AssemblyLanguage): pass
 
 def specialize_to_arm64(fn_nm, asm):
   var_size = 16
-  prev_uop = None
-  ins = [] 
+  prev_uop:Optional[UOps] = None
+  ins = []
   x_regs = ['x' + str(i) for i in reversed(range(29)) if i not in (10,11,12,13,14,15,16,17,18,19,20)]
   s_regs = ['s' + str(i) for i in reversed(range(3,30))]
   type_to_reg = {dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
@@ -37,22 +37,22 @@ def specialize_to_arm64(fn_nm, asm):
     elif reg[0] == 's':
       ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
       ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
-      ins.append(f"str x15, [sp, 16]")
+      ins.append("str x15, [sp, 16]")
       ins.append(f"ldr {reg}, [sp, 16]")
     else:
       ins.append(f"mov {reg}, #{value}")
 
   # Get variables intervals
-  live_range:Dict[str, str] = {}
+  live_range:Dict[str, List[int]] = {}
   for i, (uop, out, vin, arg) in enumerate(asm):
     for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
       live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
 
-  mem_vars:Dict[str, str] = {}
+  mem_vars:Dict[str, int] = {}
   rtor:Dict[str, str] = {}
-  def allocate_regs(vars):
+  def allocate_regs(mvars):
     nonlocal var_size
-    for v in [v for v in vars if v is not None and v.__class__ is not int and v.nm not in rtor]:
+    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
       available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
       #NOTE: Very simple spill, everything that doesn't fit in regs goes to mem
       if len(available_regs) == 0:
@@ -72,7 +72,7 @@ def specialize_to_arm64(fn_nm, asm):
         available_regs.append(rtor.pop(var))
-    # Assign registers to the variables using live ranges. 
+    # Assign registers to the variables using live ranges.
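The comment above summarizes the allocation scheme: registers are handed out per variable and reclaimed as soon as the instruction index passes the variable's last recorded use. A minimal sketch of that reclaim step, assuming rtor maps variable names to registers and live_range maps them to [first, last] use indices (the helper name is illustrative, not tinygrad's):

def free_dead_regs(rtor, live_range, i, free_pool):
  # anything whose last use is strictly before instruction i can be recycled
  for var in [v for v in list(rtor) if live_range[v][1] < i]:
    free_pool.append(rtor.pop(var))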
     allocate_regs([out] + vin)
-    # Assign temp regs to vin and load them before direct use 
+    # Assign temp regs to vin and load them before direct use
     for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
       rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
       # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
@@ -81,7 +81,7 @@
     if uop == UOps.SPECIAL:
       if arg.startswith('data'):
-        # data 8 to n into the stack 
+        # data 8 to n into the stack
         if int(arg[4:]) >= 8:
           ins.append(f"ldr x15, [x19, #{(int(arg[4:]) - 8) * 8}]")
           ins.append(f"mov {rtor[out.nm]}, x15")
@@ -108,7 +108,7 @@
       else:
         save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
         ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
-        # Save the registers before they are cleared by func call 
+        # Save the registers before they are cleared by func call
         for i,k in enumerate(save_regs,1):
           ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
         ins.append("stp x29, x30, [sp, #0]!")
@@ -157,9 +157,8 @@
       ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
       ins.append(f"cmp {rtor[vin[0].nm]}, x15")
       ins.append(f"b.lt loop_{arg[1]}")
-
-    prev_uop=uop
-    # store regs into memory if needed 
+    prev_uop = uop
+    # store regs into memory if needed
     if out is not None and out.nm in mem_vars:
       ins.append(f"mov x15, {mem_vars[out.nm]}")
       ins.append(f"str {rtor[out.nm]}, [sp, x15]")
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 447edd5016..f2cbe5af23 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -9,7 +9,7 @@
 import struct
 import numpy as np
 
 ARM64 = getenv('ARM64', False)
-if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const
+if CI and ARM64: from unicorn import Uc, UC_ARCH_ARM64, UC_MODE_ARM, UC_HOOK_CODE, arm64_const # type: ignore
 
 args = {
   'Windows': {'cflags':'', 'ext':'dll', 'exp':'__declspec(dllexport)'},
@@ -20,8 +20,8 @@ args = {
 CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#define bool uchar\n'
 ADDRESS = 0x10000
 
-# Unicorn doesn't support external calls 
-def align(addr): return (addr+4095) & ~(4095) 
+# Unicorn doesn't support external calls
+def align(addr): return (addr+4095) & ~(4095)
 mock_lm = {"sinf": np.sin, "sqrtf": np.sqrt, "exp2f": np.exp2, "log2f": np.log2}
 def emulate_ext_calls(fn, uc, address, size, user_data):
   s_in = struct.unpack('f', struct.pack('I', uc.reg_read(getattr(arm64_const, f'UC_ARM64_REG_S{fn[2][1:]}'))))[0]
   uc.reg_write(getattr(arm64_const, f'UC_ARM64_REG_S{fn[1][1:]}'), struct.unpack('I', struct.pack('f', mock_lm[fn[0]](s_in)))[0]) # type: ignore
@@ -33,6 +33,7 @@ class ClangProgram:
     # A: it seems there isn't https://stackoverflow.com/questions/28053328/ctypes-cdll-load-library-from-memory-rather-than-file
     # because ctypes.CDLL() calls dlopen (POSIX) or LoadLibrary (Windows) which require a file
     fn = f"{tempfile.gettempdir()}/clang_{hashlib.md5(prg.encode('utf-8')).hexdigest()}.{args['ext']}"
+    if binary and DEBUG >= 5: print(prg)
     if not os.path.exists(fn):
       tmp = f"{fn}.{os.getpid()}.tmp"
       if not binary:
@@ -40,7 +41,6 @@ class ClangProgram:
         subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+tmp).split(), input=prg.encode('utf-8'))
         os.rename(tmp, fn)
       else:
-        if DEBUG >= 5: print(prg)
         if CI and ARM64:
           prg = prg.split('\n') # type: ignore
           self.varsize = align(int(prg[0].split(" ")[1]))
@@ -77,5 +77,5 @@ class ClangProgram:
self.fxn(*[x._buf for x in args]) if wait: return time.monotonic()-st -renderer = fromimport("extra.assembly.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict")) +renderer = fromimport("tinygrad.codegen.assembly_arm64", "uops_to_arm64_asm") if ARM64 else functools.partial(uops_to_cstyle, CStyleLanguage(kernel_prefix=args['exp'], buffer_suffix=" restrict")) ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram) diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index c3cd95d20e..55574f6ef4 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -12,7 +12,7 @@ import wgpu # type: ignore device = get_default_device() class WebGPUProgram: - def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg) + def __init__(self, name: str, prg: str, binary=False): self.name,self.prg = name,device.create_shader_module(code=prg) def __call__(self, global_size, local_size, *bufs, wait=False): assert len(bufs) <= 8, "WEBGPU only supports 8 buffers" binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] From 0b5930d406e55ffd3cfd10f1386241f2770b3fe6 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 15 Aug 2023 09:07:26 -0700 Subject: [PATCH 8/8] more uops testing, who isn't passing right now... (#1522) * more uops * llvm refactor * update test uops * rest of the nodes * ors and ands --- test/test_uops.py | 53 ++++++++++++++++++++++------------- tinygrad/ops.py | 2 +- tinygrad/renderer/cstyle.py | 2 +- tinygrad/renderer/llvmir.py | 54 +++++++++++++++++++++--------------- tinygrad/runtime/ops_llvm.py | 3 +- 5 files changed, 69 insertions(+), 45 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index 766ad2e31f..4f8bcf8d54 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -37,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op): prg([buf]) return buf.toCPU()[0] -@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends") class TestUOps(unittest.TestCase): def _equal(self, v1, v2): if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5) def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32): for f in [_test_single_value, _test_single_value_const]: - for a in [-2.0, 2.0]: + for a in [-2.0, 0.0, 1.0, 2.0]: self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a)) - def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a)) - def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan')) - def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) - def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) - #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a) - def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32): + def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False): for f in [_test_single_value, _test_single_value_const]: - for a in [-2.0, 2.0]: - for b in [-3.0, 3.0]: + for a in [-2.0, 0.0, 1.0, 2.0]: + for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]): self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), 
fxn(a,b))
 
-  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
-  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
-  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
-  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
-  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
-  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
+  def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
+  def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
+  def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
+  def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
+  # this is not on most backends
+  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))
+
+  def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
+  def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
+  def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
+  def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
+  def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
+  def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ ... @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 class ASTRunner:
   def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
-    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
+    if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
   def build(self, runtime):
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 756160d75d..1eaa0b433d 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
     UnaryOps.SQRT: lambda x: f"sqrt({x})",
     BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
     BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
-    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
+    BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
     BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
     TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
   }
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index cf017b6949..5c00734645 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
   TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
 }
 
+dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
+
+def cast(bb, val, input_type, output_type):
+  if input_type == output_type: return val
+
+  if output_type ==
dtypes.float32: + if dtypes.is_int(input_type) or input_type == dtypes.bool: + val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType()) + elif input_type == dtypes.bfloat16: + val = bb[-1].sext(val, ir.IntType(32)) + val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16)) + val = bb[-1].bitcast(val, ir.FloatType()) + else: + val = bb[-1].fpext(val, ir.FloatType()) + return val + + if input_type == dtypes.float32: + if dtypes.is_int(output_type) or output_type == dtypes.bool: + val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type]) + elif output_type == dtypes.bfloat16: + val = bb[-1].bitcast(val, ir.IntType(32)) + val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16)) + val = bb[-1].trunc(val, ir.IntType(16)) + else: + val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type]) + return val + + raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") + def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]: # all llvm stuff goes into a module module = ir.Module(name=__file__) @@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} # create llvm function - dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)} func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()] func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name) @@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block) if uop == UOps.LOAD: assert newvar is not None and isinstance(args, (MemOp, ConstOp)) - assert newvar.dtype == dtypes.float, "newvar must be float" valid = args.valid.render(render_llvm, bb[-1]) if isinstance(args, ConstOp): + assert newvar.dtype == dtypes.float, "newvar must be float" if args.valid.min == 0 and args.valid.max == 1: val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value)) else: @@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value)) else: val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True)) - - if args.memory_dtype != newvar.dtype: - if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool: - val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType()) - elif args.memory_dtype == dtypes.bfloat16: - val = bb[-1].sext(val, ir.IntType(32)) - val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16)) - val = bb[-1].bitcast(val, ir.FloatType()) - else: - val = bb[-1].fpext(val, ir.FloatType()) + val = cast(bb, val, args.memory_dtype, newvar.dtype) lvars[newvar] = 
val if uop == UOps.STORE: assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory" idx = args.idx.render(render_llvm, bb[-1]) - element = lvars[vin[0]] - if args.memory_dtype != vin[0].dtype: - if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool: - element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype]) - elif args.memory_dtype == dtypes.bfloat16: - element = bb[-1].bitcast(element, ir.IntType(32)) - element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16)) - element = bb[-1].trunc(element, ir.IntType(16)) - else: - element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype]) + element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype) bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True)) if uop == UOps.ALU: lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin]) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index ab0f8b9a26..00d404a799 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -55,7 +55,8 @@ class LLVMProgram: LLVM.engine.finalize_object() self.fxn = LLVM.engine.get_function_address(name) - def __del__(self): LLVM.engine.remove_module(self.mod) + def __del__(self): + if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod) def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False): cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
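A closing note on the renderer plumbing threaded through these patches: tinygrad/ops.py and test/test_uops.py normalize the renderer's return value with the same expression, padding the legacy 3-tuple with binary=False so that assembly renderers can return a fourth element. A standalone sketch of that normalization (the helper name is ours, not tinygrad's):

def normalize_renderer_ret(ret):
  # legacy renderers return (src, global_size, local_size); assembly renderers
  # append a binary flag, so pad the short form with binary=False
  return ret if len(ret) == 4 else ret + (False,)

src, global_size, local_size, binary = normalize_renderer_ret(("kernel_src", [1,1,1], [1,1,1]))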