remove const_arg (#9002)

* remove const_arg

* use -m pytest

* remove test_const_arg test, variable arg on CONST does not exist.

* use base in test_const_dtype
This commit is contained in:
qazal
2025-02-10 12:45:11 +01:00
committed by GitHub
parent 0568720a68
commit b17ec42b56
5 changed files with 7 additions and 22 deletions

View File

@@ -423,9 +423,9 @@ jobs:
- name: Test quantize onnx
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
- name: Test LLVM=1 DEVECTORIZE=0
run: LLVM=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
#- name: Test CLANG=1 DEVECTORIZE=0
# run: CLANG=1 DEVECTORIZE=0 pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
# run: CLANG=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
testwebgpu:
name: Linux (WebGPU)

View File

@@ -61,12 +61,12 @@ class TestTensorUOp(unittest.TestCase):
def test_const_dtype(self):
lb: UOp = Tensor([1], dtype=dtypes.int).lazydata
assert lb.const_like(1).const_arg == 1
assert type(lb.const_like(1).const_arg) is int
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int
lb: UOp = Tensor([1], dtype=dtypes.float).lazydata
assert lb.const_like(1).const_arg == 1.0
assert type(lb.const_like(1).const_arg) is float
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
def test_contiguous_alu(self):
a = Tensor.randn(2, 2).realize()

View File

@@ -418,14 +418,6 @@ class TestUOpMethod(unittest.TestCase):
self.assertEqual(const._device, None)
with self.assertRaises(AssertionError): const.device
def test_const_arg(self):
var = UOp.variable("a", 1, 10)
with self.assertRaises(AssertionError): UOp.const(dtypes.int, var).const_arg
const = UOp.const(dtypes.int, 1)
self.assertEqual(const.const_arg, 1)
tensor_const = UOp.metaop(Ops.CONST, (), dtypes.int, Device.DEFAULT, 1)
self.assertEqual(tensor_const.const_arg, 1)
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)

View File

@@ -48,7 +48,7 @@ sym = symbolic_simple+PatternMatcher([
# reduce on stride 0 is collapsed
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),

View File

@@ -343,13 +343,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert self.op in GroupOp.Buffer, f"st_arg called on {self.op}"
return unwrap(self.st)
@property
def const_arg(self) -> ConstType:
match self.base.op:
case Ops.CONST: ret = self.base.arg
case op: raise AssertionError(f"const_arg called on {op}")
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
return ret
@property
def axis_arg(self) -> tuple[int, ...]:
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]