mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix bugs
This commit is contained in:
@@ -52,7 +52,7 @@ def flip_contract_kernel(dest:UOp, src:UOp):
|
||||
j = UOp.range(dest.shape[1], 1, AxisType.UPCAST)
|
||||
vec = src[i, j].contract(j)
|
||||
store = UOp.group(*[dest[i, k].store(vec.gep(3-k)) for k in range(4)])
|
||||
return store.end(i).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||
return store.end(i, j).sink(arg=KernelInfo(name=f"flip_contract_{dest.size}", opts_to_apply=()))
|
||||
|
||||
def slice_sum_kernel(dest:UOp, src:UOp):
|
||||
G = UOp.range(src.shape[0], 0)
|
||||
|
||||
@@ -311,6 +311,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def ended_ranges(self):
|
||||
if self.op in range_start: return self.src[range_start[self.op]:]
|
||||
if self.op is Ops.AFTER: return tuple(flatten([x.ended_ranges for x in self.src[1:]]))
|
||||
# TODO: copy isn't using range properly and isn't ending the range it uses, remove this
|
||||
if self.op in {Ops.COPY, Ops.BUFFER_VIEW}: return self.src[0].ranges
|
||||
return ()
|
||||
|
||||
# determine what ranges this is in
|
||||
@@ -820,6 +822,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
|
||||
assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call"
|
||||
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
|
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
|
||||
Reference in New Issue
Block a user