mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
use ConstType in various const function type hint (#4074)
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, F
|
||||
import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -39,12 +39,12 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
|
||||
return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
|
||||
|
||||
class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype:DType=dtypes.int32):
|
||||
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
return self.uops.add(UOps.ALU, dtype, (a, render_b), op)
|
||||
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:Union[int,float], dtype=dtypes.int32, insert_before=None) -> UOp:
|
||||
def const(self, b:ConstType, dtype:DType=dtypes.int32, insert_before=None) -> UOp:
|
||||
return self.uops.add(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
|
||||
|
||||
def cast(self, val: UOp, dtype) -> UOp: return self.uops.add(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
|
||||
|
||||
@@ -4,7 +4,7 @@ import functools, hashlib, math, operator, ctypes
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import prod, dedup
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.dtype import dtypes, DType, ConstType
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.buffer import Buffer
|
||||
@@ -32,7 +32,7 @@ class MemBuffer:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstBuffer:
|
||||
val: Union[int, float]
|
||||
val: ConstType
|
||||
dtype: DType
|
||||
st: ShapeTracker
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
|
||||
|
||||
def render_val(x, dtype):
|
||||
@@ -46,7 +46,7 @@ class AssemblyLanguage(NamedTuple):
|
||||
types: Dict[DType, str] = INVERSE_DTYPES_DICT
|
||||
supports_half: List[Op] = []
|
||||
|
||||
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: raise NotImplementedError()
|
||||
def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
|
||||
|
||||
def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
|
||||
@@ -112,7 +112,7 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
return r_label[u]
|
||||
|
||||
def const(x:Union[float,int,bool], dtype, mov=False):
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in lang.const_requires_mov:
|
||||
kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
|
||||
return out
|
||||
@@ -249,7 +249,7 @@ class PTXLanguage(AssemblyLanguage):
|
||||
|
||||
const_requires_mov = [dtypes.half, dtypes.bool]
|
||||
|
||||
def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
|
||||
def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]:
|
||||
val = render_val(x, dtype)
|
||||
if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
|
||||
return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import strip_parens, getenv, prod
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
@@ -41,13 +41,12 @@ class CStyleLanguage(NamedTuple):
|
||||
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
|
||||
|
||||
# returns a str expression of the const with the given type
|
||||
def render_const(self, x:Union[float,int,bool], var_dtype:DType) -> str:
|
||||
def render_const(self, x:ConstType, dtype:DType) -> str:
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
elif var_dtype == dtypes.float64: val = f"{x}"
|
||||
else: val = f"{x}f" if dtypes.is_float(var_dtype) else f"{x}" if dtypes.is_int(var_dtype) else f"{x}".lower()
|
||||
return (self.render_cast([val]*var_dtype.count, var_dtype)
|
||||
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
elif dtype == dtypes.float64: val = f"{x}"
|
||||
else: val = f"{x}f" if dtypes.is_float(dtype) else f"{x}" if dtypes.is_int(dtype) else f"{x}".lower()
|
||||
return (self.render_cast([val] * dtype.count, dtype) if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
|
||||
Reference in New Issue
Block a user