mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
[TYPED=1] cvar should allow dtype as a tuple (#11770)
* cvar dtype:DType|tuple[DType, ...]|None=None * fmt * add a test * list typeguard as a dep for CI * extra step to install mypy * fix venv * ci fixes * mv typeguard to testing install group * simpler TYPED=1 test * add typeguard to lint group
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -343,6 +343,8 @@ jobs:
|
||||
run: |
|
||||
python -m mypy --strict-equality --lineprecision-report .
|
||||
cat lineprecision.txt
|
||||
- name: Run TYPED=1
|
||||
run: TYPED=1 python -c "import tinygrad"
|
||||
|
||||
unittest:
|
||||
name: Unit Tests
|
||||
|
||||
1
setup.py
1
setup.py
@@ -64,6 +64,7 @@ setup(name='tinygrad',
|
||||
"pre-commit",
|
||||
"ruff",
|
||||
"numpy",
|
||||
"typeguard",
|
||||
],
|
||||
#'mlperf': ["mlperf-logging @ git+https://github.com/mlperf/logging.git@5.0.0-rc3"],
|
||||
'testing_minimal': testing_minimal,
|
||||
|
||||
@@ -681,7 +681,8 @@ class UPat(MathTrait):
|
||||
def var(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None): return UPat(dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
@functools.cache
|
||||
def cvar(name:str|None=None, dtype:DType|None=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
|
||||
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True):
|
||||
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
|
||||
@staticmethod
|
||||
def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user