This commit is contained in:
George Hotz
2026-02-06 14:10:33 +08:00
parent 85ee86f2fb
commit 7082b76775
2 changed files with 4 additions and 1 deletions

View File

@@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
vec = src[i, j].contract(j)
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
def slice_sum_kernel(dest:UOp, src:UOp):
G = UOp.range(src.shape[0], 0)

View File

@@ -311,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def ended_ranges(self):
  # Ranges this UOp closes, chosen by its op. Falls through to an empty tuple.
  # NOTE(review): siblings read this as an attribute (`x.ended_ranges`), so a
  # property-style decorator presumably sits just above this def -- confirm.
  op = self.op
  if op in range_start:
    # range_start[op] is the index in src where the ended ranges begin
    first_range = range_start[op]
    return self.src[first_range:]
  if op is Ops.AFTER:
    # gather what every source except the first one ends, as one flat tuple
    per_src = [dep.ended_ranges for dep in self.src[1:]]
    return tuple(flatten(per_src))
  # TODO: copy isn't using range properly and isn't ending the range it uses, remove this
  # (until then, every range of the first source is reported as ended here)
  if op in {Ops.COPY, Ops.BUFFER_VIEW}:
    return self.src[0].ranges
  return ()
# determine what ranges this is in
@@ -820,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
  # Wrap this UOp in an Ops.CALL node: self is the first source, followed by
  # srcs; grad_fxn and metadata are carried in the CallInfo arg.
  # The callee must not have any ranges left, or they would escape the call.
  assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
  call_srcs = (self, *srcs)
  info = CallInfo(grad_fxn, metadata)
  return UOp(Ops.CALL, self.dtype, call_srcs, info)
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)