mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
Kernel.apply_opts [pr] (#9917)
* Kernel.apply_opts [pr]: updated all `for opt in` loops. Also updated a few test_linearizer tests to not implicitly depend on hand_coded_optimizations.
* not you yet
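The change is mechanical at every call site: instead of looping over an opts list and calling apply_opt per element, callers hand the whole list to the new Kernel.apply_opts helper, which runs the same loop internally (see the kernel.py hunk at the end of this diff). A minimal sketch of the before/after pattern; the import paths and the toy AST/opts below are illustrative assumptions, not part of this PR:

    # sketch only: module paths and the toy workload are assumptions and may
    # differ between tinygrad versions; Kernel/Opt/OptOps usage mirrors the diff
    from tinygrad import Tensor
    from tinygrad.codegen.kernel import Kernel, Opt, OptOps

    ast = (Tensor.empty(4, 4) + 1).schedule()[-1].ast  # any scheduled kernel AST
    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]      # illustrative opt list

    k = Kernel(ast)
    # before this PR, every call site did:
    #   for opt in opts: k.apply_opt(opt)
    # after, the kernel owns the loop:
    k.apply_opts(opts)
    prg = k.to_program()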
@@ -82,7 +82,7 @@ if __name__ == "__main__":
#Opt(op=OptOps.UPCAST, axis=1, arg=4),
Opt(op=OptOps.LOCAL, axis=1, arg=LN),
Opt(op=OptOps.LOCAL, axis=0, arg=LN)]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program(ast_transform=ast_transform)
if getenv("FAST", 1) and Device.DEFAULT == "AMD":
#src = (pathlib.Path(__file__).parent / "fp32_sgemm_amd" / "src" / "kernel8_batched_gmem.s").read_text()

@@ -55,8 +55,7 @@ def randoms():
def ast_to_cuda_prog(compiler, ast, opts):
k = Kernel(ast)
k.required_optimizations()
- for opt in opts:
-   k.apply_opt(opt)
+ k.apply_opts(opts)
p = k.to_program()
return CUDAProgram(device, p.function_name, compiler.compile(p.src))

@@ -28,7 +28,7 @@ if __name__ == "__main__":
Opt(op=OptOps.LOCAL, axis=1, amt=2),
Opt(op=OptOps.LOCAL, axis=0, amt=2),
]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
new_src = prg.src
# can mod source here

@@ -51,7 +51,7 @@ def dataset_from_cache(fn):
lin = Kernel(eval(ast))
except Exception:
continue
- for opt in k[:-1]: lin.apply_opt(opt)
+ lin.apply_opts(k[:-1])
act = k[-1]
log_ratio = math.log(old_tm/new_tm)
#print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")

@@ -15,8 +15,7 @@ def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
k = Kernel(ast, opts=opts)
- for opt in applied_opts:
-   k.apply_opt(opt)
+ k.apply_opts(applied_opts)
return k

# load worlds, a dataset of about 12k kernels

@@ -32,7 +32,7 @@ ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)]

k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
bufs = bufs_from_lin(k)

prg = CompiledRunner(k.to_program())
test/external/external_test_hcq_fuzz_failures.py (vendored, 13 changed lines)
@@ -18,13 +18,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return

- for opt in opts:
-   try:
-     lin.apply_opt(opt)
-   except KernelOptError:
-     # it's considered fixed if we invalidated the opts
-     assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-     return
+ try:
+   lin.apply_opts(opts)
+ except KernelOptError:
+   # it's considered fixed if we invalidated the opts
+   assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+   return

(msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(lin, rtol=rtol, atol=atol)
if msg in ["PASS", "KernelOptError"]:
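Because apply_opts is just the old per-opt loop moved into Kernel, an invalid opt in the middle of the list still raises KernelOptError after the earlier opts have already been applied, so helper_test_lin above wraps the single batched call instead of each individual apply_opt. A hedged sketch of that defensive pattern; the KernelOptError import path is an assumption:

    # sketch of the batched-apply-with-bailout pattern from helper_test_lin above;
    # the import path for KernelOptError is an assumption, not taken from this PR
    from tinygrad.codegen.kernel import Kernel, KernelOptError

    def try_apply(k: Kernel, opts) -> bool:
      try:
        k.apply_opts(opts)   # applies opts in order, may stop partway through
        return True
      except KernelOptError:
        return False         # k may already carry the opts applied before the failure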
test/external/external_test_nv.py (vendored, 2 changed lines)
@@ -34,7 +34,7 @@ class TestNV(unittest.TestCase):
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
lin = Kernel(ast)
- for opt in opts: lin.apply_opt(opt)
+ lin.apply_opts(opts)
rawbufs = get_fuzz_rawbufs(lin)
prg = CompiledRunner(lin.to_program())
prg(rawbufs, {}, wait=True)
test/external/external_test_train_gpt2.py (vendored, 4 changed lines)
@@ -28,7 +28,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):

opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
run_linearizer(kernel)

def test_2(self):

@@ -48,7 +48,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):

opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
run_linearizer(kernel)

if __name__ == "__main__":
test/external/external_test_valid_remove.py (vendored, 4 changed lines)
@@ -54,7 +54,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
kernel = Kernel(ast)

- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)

p = kernel.to_program()
print(p.src)

@@ -111,7 +111,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
kernel = Kernel(ast)

- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)

p = kernel.to_program()
# ((idx1<1)?read_imagef(data1, smp, (int2)(idx0,0)):(float4)(0.0f,0.0f,0.0f,0.0f))
test/external/verify_kernel.py (vendored, 3 changed lines)
@@ -36,8 +36,7 @@ if __name__ == "__main__":
with open(args.pkl, 'rb') as file:
(ast, applied_opts,) = pickle.load(file)
lin = Kernel(ast)
- for opt in applied_opts:
-   lin.apply_opt(opt)
+ lin.apply_opts(applied_opts)
test_lins = [lin]

else:
@@ -1001,7 +1001,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Kernel(r.schedule()[-1].ast)
- k = hand_coded_optimizations(k)
+ k.apply_opts([Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()

stores = [u for u in k.uops if u.op is Ops.STORE]
@@ -1306,7 +1306,6 @@ class TestLinearizer(unittest.TestCase):
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
lin = Kernel(sched_copy[-1].ast)
- lin = hand_coded_optimizations(lin)
lin.linearize()
assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
@@ -1455,8 +1454,6 @@ class TestFloat4(unittest.TestCase):
return (len([uop for uop in k.uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
len([uop for uop in k.uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))

# TODO: express opts below as auto opts

def test_float4_basic(self):
a = Tensor.empty(2, 8).realize()
b = Tensor.empty(2, 8).realize()
@@ -1464,7 +1461,7 @@ class TestFloat4(unittest.TestCase):

s = c.schedule()[0]
k = Kernel(s.ast)
- k = hand_coded_optimizations(k)
+ k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()

assert TestFloat4.count_float4(k) == (2, 1)
@@ -1511,7 +1508,6 @@ class TestFloat4(unittest.TestCase):
for i in range(len(sizes)):
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]

@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
def test_float4_unaligned_load(self):
a = Tensor.empty(9).realize().shrink(((1, 9),))
b = Tensor.empty(9).realize().shrink(((1, 9),))
@@ -1519,7 +1515,7 @@ class TestFloat4(unittest.TestCase):

s = c.schedule()[0]
k = Kernel(s.ast)
- k = hand_coded_optimizations(k) # implicit trigger float4 dim
+ k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()

assert TestFloat4.count_float4(k) == (0, 1)
@@ -1667,7 +1663,7 @@ class TestFloat4(unittest.TestCase):
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = TestFloat4.count_half4(k)
assert count == expected, f"{count=}, {expected=}"
@@ -1697,7 +1693,7 @@ class TestFloat4(unittest.TestCase):
(4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
assert count == expected, f"{count=}, {expected=}"
@@ -1720,7 +1716,7 @@ class TestFloat4(unittest.TestCase):
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
assert count == expected, f"{count=}, {expected=}"
@@ -1838,8 +1834,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
if apply_tc:
assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
else:
- for opt in opts:
-   k.apply_opt(opt)
+ k.apply_opts(opts)
if expected_color_size is not None:
cs = list(zip(k.colors(), k.full_shape))
assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
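The test_linearizer edits above all follow one shape: build the Kernel, state the opts explicitly with apply_opts instead of leaning on hand_coded_optimizations, linearize, then assert on the resulting uops. A hedged sketch of that shape with assumed import paths and a toy elementwise kernel:

    # sketch of the explicit-opts test pattern; import paths are assumptions
    from tinygrad import Tensor
    from tinygrad.ops import Ops
    from tinygrad.codegen.kernel import Kernel, Opt, OptOps

    a, b = Tensor.empty(2, 8).realize(), Tensor.empty(2, 8).realize()
    s = (a + b).schedule()[0]
    k = Kernel(s.ast)
    k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])  # explicit, no hand_coded_optimizations
    k.linearize()
    loads = [u for u in k.uops if u.op is Ops.LOAD]       # inspect the linearized uops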
@@ -38,7 +38,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
k = Kernel(ast, opts=Device["METAL"].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
Device[Device.DEFAULT].compiler.compile_cached(prg.src)

@@ -73,7 +73,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"

@@ -91,7 +91,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is Ops.IF]

@@ -157,7 +157,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3]

@@ -188,7 +188,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]

@@ -212,7 +212,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]
@@ -17,13 +17,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, rtol=1e-2, atol=1e-2):
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return

- for opt in opts:
-   try:
-     lin.apply_opt(opt)
-   except KernelOptError:
-     # it's considered fixed if we invalidated the opts
-     assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-     return
+ try:
+   lin.apply_opts(opts)
+ except KernelOptError:
+   # it's considered fixed if we invalidated the opts
+   assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+   return

compare_result = compare_linearizer(lin, rtol=rtol, atol=atol)
if compare_result[0] in ["PASS", "KernelOptError"]:
@@ -14,7 +14,7 @@ from tinygrad.shape.view import View

def _test_overflow(ast, opts):
lin = Kernel(ast)
- for opt in opts: lin.apply_opt(opt)
+ lin.apply_opts(opts)
lin.linearize()
bufs = bufs_from_lin(lin)
print(bufs)

@@ -40,7 +40,7 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
si = out.schedule()[-1]
k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
#opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
if replace_src is not None:
old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]

@@ -296,7 +296,7 @@ class TestDSPCache(unittest.TestCase):
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
with Context(DEVECTORIZE=0, QUANTIZE=1):
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
#print(prg.src)
@@ -318,8 +318,7 @@ class Kernel:
self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))

if (tc_opts:=self.tensor_core_opts) is not None:
- if extra_opts is not None:
-   for opt in extra_opts: self.apply_opt(opt)
+ if extra_opts is not None: self.apply_opts(extra_opts)
else:
if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
# hand-coded TC opts
@@ -430,6 +429,9 @@ class Kernel:
if self.simplify_ones() and self.tensor_core_opts:
self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()

+ def apply_opts(self, opts:Sequence[Opt]):
+   for opt in opts: self.apply_opt(opt)

def required_optimizations(self) -> Kernel:
if isinstance(self.membufs[0].dtype, ImageDType):
unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]