Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 23:08:06 -05:00)
Fix gpt optimization (#1885)
* fix for gpt
* the actual fix
* Remove change in symbolic
* small comment
@@ -337,10 +337,6 @@ class OptimizedKernel(Kernel):
     # no more opt if we are grouping
     if self.group_for_reduce: return

-    # no more opt if there's non ints in any shapes
-    # TODO: this is due to a bug. repro by commenting this one while running GPT-2 with the JIT
-    if self.has_variable_shape(): return
-
     # **** below this line need to be optional and benchmarked ****

     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
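The block removed above was a kernel-wide gate: if any entry of the full shape was non-integer (symbolic), all hand-coded optimization was skipped. As a hedged illustration only (the `SymDim` stand-in and this `has_variable_shape` body are assumptions for the sketch, not tinygrad's actual symbolic classes or implementation), such a gate boils down to an `any(...)` over the shape:

```python
# Hedged sketch: what a kernel-wide "has variable shape" gate amounts to.
# SymDim is a stand-in for a symbolic dimension; it is NOT tinygrad's real
# symbolic Variable/Node class, and this is not the real has_variable_shape body.
from typing import Sequence, Union

class SymDim:
  """Toy symbolic dimension with a name and a value range (assumption)."""
  def __init__(self, name: str, min_val: int, max_val: int):
    self.name, self.min_val, self.max_val = name, min_val, max_val

Dim = Union[int, SymDim]

def has_variable_shape(full_shape: Sequence[Dim]) -> bool:
  # True if any dimension is not a concrete int
  return any(not isinstance(d, int) for d in full_shape)

# One symbolic axis (e.g. the growing sequence length of a JIT'd GPT-2 decode
# step) was enough to make the old gate return early and skip all optimization:
print(has_variable_shape((64, 128, 768)))                   # False
print(has_variable_shape((64, SymDim("i", 1, 128), 768)))   # True
```

The rest of the commit replaces this all-or-nothing early return with a per-axis check in the masked-axis upcast below, so a symbolic axis is skipped individually instead of disabling every optimization for GPT-2 under the JIT.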
@@ -349,7 +345,8 @@ class OptimizedKernel(Kernel):
     # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
     for axis in range(self.first_reduce):
       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
-      if self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
+      # for now skip upcasting here if there is a symbolic axis
+      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and prod(self.full_shape[self.shape_len - self.upcasted:]) * self.full_shape[axis] <= 7 * 7:
         if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
         to_upcast.append(axis)
     for axis in to_upcast[::-1]:
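The added `isinstance(self.full_shape[axis], int)` guard is evaluated first, so the `<= 7` size heuristic never runs on a symbolic dimension. A minimal stand-alone sketch of the same guard pattern (the toy `SymDim` and `maskable_small_axes` names are assumptions for illustration, not tinygrad APIs):

```python
# Hedged sketch of the per-axis guard pattern added in this hunk (stand-alone
# toy code, not tinygrad's kernel). The ordering matters: isinstance() comes
# first, so `dim <= 7` is never evaluated for a symbolic dimension.
from typing import Sequence, Union

class SymDim:
  """Toy symbolic dimension (assumption, same stand-in as the sketch above)."""
  def __init__(self, name: str, min_val: int, max_val: int):
    self.name, self.min_val, self.max_val = name, min_val, max_val

Dim = Union[int, SymDim]

def maskable_small_axes(full_shape: Sequence[Dim], masked: Sequence[bool]) -> list:
  """Pick small, masked, concrete-int axes to upcast; silently skip symbolic ones."""
  to_upcast, upcast_size = [], 1
  for axis, dim in enumerate(full_shape):
    # per-axis guard: only apply the <= 7 heuristic to plain ints
    if isinstance(dim, int) and dim <= 7 and masked[axis] and upcast_size * dim <= 7 * 7:
      to_upcast.append(axis)
      upcast_size *= dim
  return to_upcast

# A symbolic axis no longer aborts optimization; it is just not upcast:
print(maskable_small_axes((4, SymDim("i", 1, 128), 3), [True, True, True]))  # [0, 2]
```

Short-circuit evaluation of `and` is what makes the ordering matter: the integer check must come before any arithmetic comparison on the dimension.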
@@ -408,4 +405,4 @@ class OptimizedKernel(Kernel):
     for axis, local_sz in sorted(to_local[:3]):
       self.shift_to(axis, local_sz, insert_before=self.first_reduce)
       self.local_dims += 1
-    self.simplify_ones()
+    self.simplify_ones()
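The last hunk appears to change only whitespace around the trailing `simplify_ones()` call; the surrounding context is the local-dimension pass, which takes up to three candidate axes and shifts a factor of each in front of `first_reduce`. A hedged sketch of what such an axis split does to a shape (illustration only, with a hypothetical `split_axis` helper; tinygrad's real `shift_to` also rewrites the ShapeTracker views and strides):

```python
# Hedged sketch of what a shift_to-style axis split does to a shape (illustration
# only; tinygrad's real shift_to also rewrites the ShapeTracker views/strides).
def split_axis(shape: tuple, axis: int, amount: int, insert_before: int) -> tuple:
  """Split shape[axis] into (shape[axis] // amount, amount) and place the
  amount-sized piece at index insert_before (e.g. as a new local dimension)."""
  assert shape[axis] % amount == 0, "axis must divide evenly"
  new_shape = list(shape)
  new_shape[axis] //= amount
  new_shape.insert(insert_before, amount)
  return tuple(new_shape)

# Moving a 16-wide slice of axis 2 in front of a (hypothetical) first reduce axis at index 2:
print(split_axis((256, 8, 64, 32), axis=2, amount=16, insert_before=2))  # (256, 8, 16, 4, 32)
```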