diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index 5a90ef0522..9a7391db17 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -471,7 +471,7 @@ class Group: idxs = tuple(idx * rv.length if i == 3 else idx for i, idx in enumerate(idxs)) dst_i = ((idxs[0] * dst.shape[-3] + idxs[1]) * dst.shape[-2] + idxs[2]) * dst.shape[-1] + idxs[3] - for outer in self.ker.range(src.shape[-2]): + for outer in self.ker.range(src.shape[-2], track=False): dst_i += outer * reductions + (laneid % reductions) src_load = src[outer, 0]