free up lines for const_arg [pr] (#8083)

This commit is contained in:
qazal
2024-12-06 16:28:51 +02:00
committed by GitHub
parent ba35c4138b
commit 79966fade0

View File

@@ -106,23 +106,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(r:UOp, v:UOp, src:UOp) -> UOp:
  """Push the swizzle (VIEW) `v` down through the reduce `r`.

  `src` is reduced directly and the result is then viewed with the shape obtained by
  reducing `v`'s ShapeTracker over `r.axis_arg`.

  Args:
    r: the reduce being rewritten; its `arg[0]` and `axis_arg` are reused for the new reduce.
    v: the UOp whose ShapeTracker is the swizzle to push down.
    src: the UOp that gets reduced.

  Returns:
    `src` reduced over the recomputed axes, viewed through a fresh ShapeTracker of the output shape.

  Raises:
    AssertionError: if `v`'s ShapeTracker is not contiguous or `v` and `src` differ in size.
  """
  # explicit raise instead of `assert`, so the check is not stripped under `python -O`
  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
  output_shape = swizzle_st.reduce(r.axis_arg)
  # the axes to reduce are exactly the positions where src's shape differs from the output shape
  new_axis = tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)
  return src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
  """Push the swizzled sources of `root` below `root`.

  A source `x` counts as a swizzle when `x.base is not x`. Each swizzle is replaced by its own
  `src[0]`, sources without a ShapeTracker (`has_st` false) are kept unchanged, and every other
  source is passed through `apply_swizzle` with a ShapeTracker built from the first swizzle's input shape.

  Args:
    root: the UOp whose sources are inspected.

  Returns:
    None when `root` has no swizzled source. Otherwise the rewritten `root`: returned directly
    when its op is `Ops.STORE`, else wrapped in a view with the shape of the first swizzle.

  Raises:
    AssertionError: if the swizzles do not all share the same shape and the same input size.
  """
  if not (swizzles := [x for x in root.src if x.base is not x]): return None
  # explicit raise instead of `assert`, so the check is not stripped under `python -O`
  if not all_same([(x.shape, prod(x.src[0].shape)) for x in swizzles]):
    raise AssertionError(f"swizzles must have the same size {swizzles}")
  new_input_st = ShapeTracker.from_shape(swizzles[0].src[0].shape)
  new_src = tuple(
    x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, new_input_st)
    for x in root.src)
  ret = root.replace(src=new_src)
  # update the ASSIGN offset to match the new shape
  if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(swizzles[0].shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"