use ConstType in various const function type hint (#4074)

This commit is contained in:
chenyu
2024-04-04 20:32:07 -04:00
committed by GitHub
parent c1cffed1df
commit 5e6e6c9a67
4 changed files with 14 additions and 15 deletions

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, F
import itertools, math, functools
from collections import defaultdict
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -39,12 +39,12 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
class Linearizer(Kernel):
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype:DType=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uops.add(UOps.ALU, dtype, (a, render_b), op)
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before=None) -> UOp:
return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val

View File

@@ -4,7 +4,7 @@ import functools, hashlib, math, operator, ctypes
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import prod, dedup
from tinygrad.dtype import dtypes, DType
from tinygrad.dtype import dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.buffer import Buffer
@@ -32,7 +32,7 @@ class MemBuffer:
@dataclass(frozen=True)
class ConstBuffer:
val: Union[int, float]
val: ConstType
dtype: DType
st: ShapeTracker

View File

@@ -3,7 +3,7 @@ import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
def render_val(x, dtype):
@@ -46,7 +46,7 @@ class AssemblyLanguage(NamedTuple):
types: Dict[DType, str] = INVERSE_DTYPES_DICT
supports_half: List[Op] = []
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
@@ -112,7 +112,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
return r_label[u]
def const(x:Union[float,int,bool], dtype, mov=False):
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in lang.const_requires_mov:
kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
return out
@@ -249,7 +249,7 @@ class PTXLanguage(AssemblyLanguage):
const_requires_mov = [dtypes.half, dtypes.bool]
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
val = render_val(x, dtype)
if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

View File

@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph
class CStyleLanguage(NamedTuple):
@@ -41,13 +41,12 @@ class CStyleLanguage(NamedTuple):
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int,bool], var_dtype:DType) -> str:
def render_const(self, x:ConstType, dtype:DType) -> str:
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
elif var_dtype == dtypes.float64: val = f"{x}"
else: val = f"{x}f" if dtypes.is_float(var_dtype) else f"{x}" if dtypes.is_int(var_dtype) else f"{x}".lower()
return (self.render_cast([val]*var_dtype.count, var_dtype)
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
elif dtype == dtypes.float64: val = f"{x}"
else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}" if dtypes.is_int(dtype) else f"{x}".lower()
return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: