Fix gpt optimization (#1885)

* fix for gpt

* the actual fix

* Remove change in symbolic

* small comment
This commit is contained in:
Gijs Koning
2023-09-21 04:28:18 +02:00
committed by GitHub
parent bd3444797b
commit 9eb6310686

View File

@@ -337,10 +337,6 @@ class OptimizedKernel(Kernel):
# no more opt if we are grouping
if self.group_for_reduce: return
# no more opt if there's non ints in any shapes
# TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT
if self.has_variable_shape(): return
# **** below this line need to be optional and benchmarked ****
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
@@ -349,7 +345,8 @@ class OptimizedKernel(Kernel):
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in range(self.first_reduce):
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
if self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
# for now skip upcasting here if there is a symbolic axis
if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]:
@@ -408,4 +405,4 @@ class OptimizedKernel(Kernel):
for axis, local_sz in sorted(to_local[:3]):
self.shift_to(axis, local_sz, insert_before=self.first_reduce)
self.local_dims += 1
self.simplify_ones()
self.simplify_ones()