This commit is contained in:
George Hotz
2023-06-27 20:28:57 -07:00
committed by GitHub
parent d16c16ec28
commit 9fabdbd054

View File

@@ -545,10 +545,12 @@ class Linearizer:
else:
break
# if last dim <= 16 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
if self.full_unupcasted_shape[-1] <= 16:
if (s:=self.full_unupcasted_shape[-1]) <= 32:
self.upcast()
# if it's small, upcast a second reduce dimension too
if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast()
else:
for splits in [4]:
if self.full_unupcasted_shape[-1]%splits == 0: