diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 4145e88589..fa850fee6a 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -545,10 +545,12 @@ class Linearizer: else: break - # if last dim <= 16 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS + # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))): - if self.full_unupcasted_shape[-1] <= 16: + if (s:=self.full_unupcasted_shape[-1]) <= 32: self.upcast() + # if it's small, upcast a second reduce dimension too + if self.first_reduce < (self.shape_len-self.upcasted) and s <= 3 and self.full_unupcasted_shape[-1] <= 3: self.upcast() else: for splits in [4]: if self.full_unupcasted_shape[-1]%splits == 0: