Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
apply flake8 E203 rule (#684)
.github/workflows/test.yml | 2 (vendored)
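For context, flake8's E203 check flags whitespace before a colon. This commit adds E203 to the flake8 --select list in CI and in the pre-commit hook, and removes the offending space before the colon in type annotations throughout the code. A minimal, hypothetical illustration of the style being enforced (not code taken from the repository):

from typing import List

class Example:
    # old style: the space before ':' is flagged as E203 once the rule is enabled
    axis_old : List[int] = []
    # new style: no space before the annotation colon, which is what this commit applies
    axis_new: List[int] = []

The flake8 command shown in the workflow diff below reproduces the same check locally.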
@@ -36,7 +36,7 @@ jobs:
 - name: Lint with pylint
 run: python -m pylint --disable=all -e W0311 --jobs=0 --indent-string='  ' **/*.py
 - name: Lint with flake8
-run: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
+run: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E203,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
 - name: Lint tinygrad with pylint
 run: pylint tinygrad/
 - name: Run mypy
@@ -9,7 +9,7 @@ repos:
 pass_filenames: false
 - id: flake8
 name: flake8
-entry: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
+entry: flake8 tinygrad/ --indent-size=2 --select=F,E112,E113,E203,E304,E502,E702,E703,E71,E72,E731,W191,W6 --statistics -j4
 language: system
 always_run: true
 pass_filenames: false
@@ -16,7 +16,7 @@ class Token:
 def __init__(self, tok:str, typ:Types, ptr:bool=False):
 assert isinstance(tok, str)
 self.tok, self.typ, self.ptr = tok, typ, ptr
-self.axis : List[Tuple[int, int, bool]] = []
+self.axis: List[Tuple[int, int, bool]] = []
 def array(self, length, stride, reduce): self.axis.append((length, stride, reduce))
 def size(self): return prod([x[0] for x in self.axis])
 def offsets(self): return [sum(t) for t in itertools.product(*[[y*x[1] for y in range(x[0])] for x in self.axis[::-1]])] if len(self.axis) else [0]
@@ -73,7 +73,7 @@ class ASTKernel:
 self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []

 self.buftokens = [Token(f"data{i}", Types.FLOAT, ptr=True) for i in range(len(self.bufs))]
-self.group_for_reduce : List[int] = []
+self.group_for_reduce: List[int] = []

 # check valid AST kernel
 assert all_same([x.shape for x in self.earlybufs]), "all earlybufs must have the same shape"
@@ -81,10 +81,10 @@ class ASTKernel:
 assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"

 # get full shape buf index (earlybufs if there are any, otherwise output)
-self.full_buf_index : int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0
+self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0

 # process
-self.sts : List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel
+self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs] # create new shapetrackers inside this kernel

 # move all reduce axes to the end
 reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
@@ -16,16 +16,16 @@ VALIDHACKS = getenv("VALIDHACKS", 0) # TODO: remove the need for this
 NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the tests to pass

 class GPULanguage(NamedTuple):
-kernel_prefix : str = ""
-buffer_prefix : str = ""
-buffer_suffix : str = ""
-smem_prefix : str = ""
-barrier : str = ""
-gid : List[str] = []
-lid : List[str] = []
-extra_args : List[str] = []
-float4 : Optional[str] = None
-half_prekernel : Optional[str] = None
+kernel_prefix: str = ""
+buffer_prefix: str = ""
+buffer_suffix: str = ""
+smem_prefix: str = ""
+barrier: str = ""
+gid: List[str] = []
+lid: List[str] = []
+extra_args: List[str] = []
+float4: Optional[str] = None
+half_prekernel: Optional[str] = None

 def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
 idy = (idxy//(4*base_shape[1]))
@@ -48,13 +48,13 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
 return idx, idy

 class GPUCodegen(ASTKernel):
-lang : ClassVar[GPULanguage] = GPULanguage()
+lang: ClassVar[GPULanguage] = GPULanguage()

 # for renaming
-kernel_cnt : Final[DefaultDict[str, int]] = defaultdict(int)
-kernel_name_cache : Final[Dict[str, str]] = {}
+kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
+kernel_name_cache: Final[Dict[str, str]] = {}

-code_for_op : Final[Dict[Op, str]] = {
+code_for_op: Final[Dict[Op, str]] = {
 UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "(1.0f-A)",
 UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
 UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
@@ -62,7 +62,7 @@ class GPUCodegen(ASTKernel):
 BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
 BinaryOps.MAX: "max(A,B)", ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"
 }
-start_for_op : Final[Dict[Op, str]] = {ReduceOps.SUM: "0.0f", ReduceOps.MAX: "-INFINITY"}
+start_for_op: Final[Dict[Op, str]] = {ReduceOps.SUM: "0.0f", ReduceOps.MAX: "-INFINITY"}

 def group_float4(self, grp:List[Token]) -> Token:
 if all(g.tok.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.tok.split(".")[0] for g in grp]): return Token(grp[0].tok.split(".")[0], Types.FLOAT4)
@@ -142,7 +142,7 @@ class GPUCodegen(ASTKernel):
 def ast_parse(self, x, acc:List[Token], do_reduce=False) -> List[Token]:
 if not isinstance(x, LazyOp): return self.load(self.bufs.index(x), "mid" if x is None else None) # hack for local
 if isinstance(x.op, ReduceOps) and not do_reduce: return acc
-values : List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
+values: List[List[Token]] = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
 code = GPUCodegen.code_for_op[x.op] # TODO: replace this with a function
 if len(values) == 2:
 assert len(values[0]) == len(values[1]) and values[0][0].typ == values[1][0].typ, f"values mismatch {values}"
@@ -256,16 +256,16 @@ class GPUCodegen(ASTKernel):
 self.sts.append(st)
 self.buftokens.append(buftoken)

-self.output_shape : Tuple[int, ...] = self.sts[0].shape[:self.first_reduce] + tuple(self.group_for_reduce)
+self.output_shape: Tuple[int, ...] = self.sts[0].shape[:self.first_reduce] + tuple(self.group_for_reduce)
 assert self.full_shape[:len(self.output_shape)] == self.output_shape, f"output shape mismatch : {self.full_shape[:len(self.output_shape)]} != {self.output_shape}"
 if DEBUG >= 4:
 print("output shape", self.output_shape)
 self.printbufs("new:", DEBUG>=5)

-self.bufs_to_delete : Set[int] = set()
-self.loaded_keys : Dict[Tuple[int,int], Token] = {}
-self.prekernel : Set[str] = set()
-self.kernel : List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
+self.bufs_to_delete: Set[int] = set()
+self.loaded_keys: Dict[Tuple[int,int], Token] = {}
+self.prekernel: Set[str] = set()
+self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []

 if self.lang.half_prekernel: self.prekernel.add(self.lang.half_prekernel+"\n")

@@ -284,7 +284,7 @@ class GPUCodegen(ASTKernel):
 if DEBUG >= 4: print(f"replaced output shape with {self.output_shape}")

 # early ast
-accumulators : List[Token] = []
+accumulators: List[Token] = []
 if self.reduceop is not None:
 accumulators = self.get_accumulators()
 self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.full_shape[i]}; idx{i}++) {{\n" for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len)]
@@ -20,7 +20,7 @@ render_llvm = {
 }

 class LLVMCodegen(ASTKernel):
-op_lookup : ClassVar = {
+op_lookup: ClassVar = {
 UnaryOps.NOOP: lambda builder,x: x,
 UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
 UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
@@ -34,7 +34,7 @@ class LLVMCodegen(ASTKernel):
 BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType()),
 BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',))
 }
-start_for_op : ClassVar = {
+start_for_op: ClassVar = {
 ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
 ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
 }
@@ -44,7 +44,7 @@ class LLVMCodegen(ASTKernel):
 if DEBUG >= 3: self.printbufs("old:", DEBUG>=4)

 # this stuff can't be hand coded
-kernel_output_axis : List[int] = []
+kernel_output_axis: List[int] = []
 """
 CACHE_DIM = 32
 if len(k.shapes[0]) == 2:
@@ -13,7 +13,7 @@ GRAPH, PRUNEGRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("PRUNEGRAPH", 0), gete
 # **** debugging and graphing ****

 G = nx.DiGraph() if nx is not None else None
-cnts : Dict[OpType, int] = defaultdict(int)
+cnts: Dict[OpType, int] = defaultdict(int)
 if GRAPH:
 def save_graph_exit():
 for k,v in cnts.items(): print(k, v)
@@ -32,16 +32,16 @@ def nm(x):
 node_count += 1
 return x.node_id

-def get_sop(op : List[Op]):
+def get_sop(op: List[Op]):
 if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
 if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
 return str(len(op))

-def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None):
+def log_op(ret: DeviceBuffer, ast: LazyOp, show_graph: Optional[bool] = None):
 if show_graph is None: show_graph = bool(GRAPH)
 if not DEBUG and not show_graph: return
-op : List[Op] = [x.op for x in get_lazyops(ast)]
-inp : List[DeviceBuffer] = get_buffers(ast)
+op: List[Op] = [x.op for x in get_lazyops(ast)]
+inp: List[DeviceBuffer] = get_buffers(ast)
 if len(inp) == 1 and inp[0] == ret:
 if show_graph and nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold'
 return # don't log self loops
@@ -21,9 +21,9 @@ DEBUG, IMAGE = getenv("DEBUG", 0), getenv("IMAGE", 0)
 # **** tinygrad now supports dtypes! *****

 class DType(NamedTuple):
-itemsize : int
-name : str
-np : type # TODO: someday this will be removed with the "remove numpy" project
+itemsize: int
+name: str
+np: type # TODO: someday this will be removed with the "remove numpy" project
 def __repr__(self): return f"dtypes.{self.name}"

 class LazyNumpyArray:
@@ -34,7 +34,7 @@ class LazyNumpyArray:
 def astype(self, typ): return self

 class dtypes:
-float16 : Final[DType] = DType(2, "half", np.float16)
-float32 : Final[DType] = DType(4, "float", np.float32)
+float16: Final[DType] = DType(2, "half", np.float16)
+float32: Final[DType] = DType(4, "float", np.float32)
 @staticmethod
-def from_np(x:Union[LazyNumpyArray, np.ndarray]) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x.dtype)]
+def from_np(x:Union[LazyNumpyArray, np.ndarray]) -> DType: return {np.dtype(np.float16): dtypes.float16, np.dtype(np.float32): dtypes.float32}[np.dtype(x.dtype)]
@@ -8,16 +8,16 @@ from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer

 class TinyJit:
 def __init__(self, fxn:Callable):
-self.fxn : Callable = fxn
-self.cnt : int = 0
-self.jit_cache : List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails
-self.ret : Any = None
-self.input_replace : Dict[Tuple[int, int], Union[int, str]]= {}
+self.fxn: Callable = fxn
+self.cnt: int = 0
+self.jit_cache: List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails
+self.ret: Any = None
+self.input_replace: Dict[Tuple[int, int], Union[int, str]]= {}

 def __call__(self, *args, **kwargs) -> Any:
 if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
 # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
-input_rawbuffers : Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(CompiledBuffer, v.realize().lazydata.realized).raw() for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
+input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(CompiledBuffer, v.realize().lazydata.realized).raw() for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
 assert len(input_rawbuffers) != 0, "no inputs to JIT"
 if self.cnt >= 2:
 for (j,i),idx in self.input_replace.items(): self.jit_cache[j][1][i] = input_rawbuffers[idx]
@@ -16,7 +16,7 @@ LAZY = getenv("LAZY", 1)
 class _Device:
 def __init__(self) -> None:
 self._buffers = {y.upper():y for y in [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")]}
-self.DEFAULT : str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
+self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
 @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
 def __getitem__(self, x:str) -> Type[DeviceBuffer]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{self._buffers[x]}'), inspect.isclass) if (cname.lower() == self._buffers[x] + "buffer")][0]
 Device = _Device()
@@ -37,10 +37,10 @@ def _ast_reduceops(self:LazyBuffer) -> LazyOp:

 # this supports late merging an upstream Reduce op and even an Elementwise op above that
 def _ast_binaryops(self:LazyBuffer) -> LazyOp:
-real_srcs : Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
+real_srcs: Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
 # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
-psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
-intermediate_shape : Tuple[int, ...] = self.shape
+psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
+intermediate_shape: Tuple[int, ...] = self.shape
 if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
 if psrcs[0][1].optype == ReduceOps:
 top = _ast_reduceops(psrcs[0][1])
@@ -75,7 +75,7 @@ def support_weakref(x): return x
 @support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class
 class LazyBuffer:
 __deletable__ = ('op',)
-lazycache : ClassVar[WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer]] = WeakValueDictionary()
+lazycache: ClassVar[WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer]] = WeakValueDictionary()
 def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType):
 # fromcpu aren't cached
 if optype == LoadOps and op.op == LoadOps.FROMCPU:
@@ -91,11 +91,11 @@ class LazyBuffer:
 return # cache hit, we return and don't reinit
 self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
 self.shape, self.optype, self.op, self.dtype = self.st.shape, optype, op, dtype
-self.realized : Optional[DeviceBuffer] = None
-self.output_buffer : Optional[DeviceBuffer] = None
+self.realized: Optional[DeviceBuffer] = None
+self.output_buffer: Optional[DeviceBuffer] = None
 self.device, self.dbuffer = device, Device[device]
 # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
-self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
+self.children: weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
 # NOTE: op should be read only after construction of LazyBuffer
 for x in get_buffers(op): x.children.add(self)
 if not LAZY: self.realize()
@@ -203,7 +203,7 @@ class LazyBuffer:
 # move permutes before reshapes if we can
 if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
 if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
-new_arg : List[int] = functools.reduce(lambda r, x: r + shape_idx_groups[x], arg, [])
+new_arg: List[int] = functools.reduce(lambda r, x: r + shape_idx_groups[x], arg, [])
 self.op.src[0].children.discard(self) # this changes nothing?
 return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(new_arg)) \
 .movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
@@ -59,8 +59,8 @@ class Linear:
 class GroupNorm:
 def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
 self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
-self.weight : Optional[Tensor] = Tensor.ones(num_channels) if affine else None
-self.bias : Optional[Tensor] = Tensor.zeros(num_channels) if affine else None
+self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
+self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None

 def __call__(self, x:Tensor):
 # reshape for layernorm to work as group norm
@@ -3,13 +3,13 @@ from typing import List
 from tinygrad.tensor import Tensor

 class Optimizer:
-def __init__(self, params : List[Tensor]):
+def __init__(self, params: List[Tensor]):
 # if it's None, but being put into an optimizer, set it to True
 for x in params:
 if x.requires_grad is None: x.requires_grad = True

-self.params : List[Tensor] = [x for x in params if x.requires_grad]
-self.buffers : List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized
+self.params: List[Tensor] = [x for x in params if x.requires_grad]
+self.buffers: List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized

 # TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
 def clipnorm(self, amount=1):
@@ -27,7 +27,7 @@ class Optimizer:
 p.realize()

 class SGD(Optimizer):
-def __init__(self, params : List[Tensor], lr=0.001, momentum=0, nesterov=False):
+def __init__(self, params: List[Tensor], lr=0.001, momentum=0, nesterov=False):
 super().__init__(params)
 self.lr, self.momentum, self.nesterov = lr, momentum, nesterov
 self.b = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params] if self.momentum else []
@@ -44,7 +44,7 @@ class SGD(Optimizer):
 self.realize(self.b)

 class RMSprop(Optimizer):
-def __init__(self, params : List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
+def __init__(self, params: List[Tensor], lr=0.001, decay=0.9, eps=1e-8):
 super().__init__(params)
 self.lr, self.decay, self.eps = lr, decay, eps

@@ -58,7 +58,7 @@ class RMSprop(Optimizer):
 self.realize(self.v)

 class Adam(Optimizer):
-def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
+def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
 super().__init__(params)
 # NOTE: self.t is a tensor so Adam can be jitted
 self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
@@ -77,7 +77,7 @@ class Adam(Optimizer):
 self.realize([self.t] + self.m + self.v)

 def get_parameters(obj) -> List[Tensor]:
-parameters : List[Tensor] = []
+parameters: List[Tensor] = []
 if isinstance(obj, Tensor):
 parameters.append(obj)
 elif isinstance(obj, (list, tuple)):
@@ -39,9 +39,9 @@ class Copyable:

 class RawBuffer(Copyable): # pylint: disable=abstract-method
 def __init__(self, size:int, dtype:DType):
-self.size : int = size
-self.dtype : DType = dtype
-self._memsz : int = size*dtype.itemsize
+self.size: int = size
+self.dtype: DType = dtype
+self._memsz: int = size*dtype.itemsize
 GlobalCounters.mem_used += self._memsz
 def __del__(self): GlobalCounters.mem_used -= self._memsz

@@ -81,7 +81,7 @@ class GenericShape:
 def consume_flops(self):
 self.flops, ret = 0, self.flops
 return ret
-shape_fxn_for_op : Dict[Op, Callable] = {
+shape_fxn_for_op: Dict[Op, Callable] = {
 **{op:lambda self: GenericShape(self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps},
 **{op:lambda self,y: GenericShape(self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
 **{op:lambda self,new_shape: GenericShape(new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps},
@@ -90,7 +90,7 @@ def get_lazyop_info(ast:LazyOp): return InterpretedBuffer.exec_ast(map_buffers({

 # used in CPUBuffer and TorchBuffer
 class InterpretedBuffer(DeviceBuffer): # pylint: disable=abstract-method
-fxn_for_op : ClassVar = shape_fxn_for_op
+fxn_for_op: ClassVar = shape_fxn_for_op
 def __init__(self, lbuf:Any): self._buf, self.shape, self.dtype = lbuf, tuple(lbuf.shape), self.to_tinygrad_dtype(lbuf) if hasattr(self, 'to_tinygrad_dtype') else lbuf.dtype
 def contiguous(self): return type(self).exec_ast(LazyOp(op=UnaryOps.NOOP, src=(self,)))
 def movement_op(self, op:MovementOps, arg=None): return type(self)(self.fxn_for_op[op](self._buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self._buf, op.name.lower())(arg))
@@ -162,15 +162,15 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
 self.shape = self.st.shape
 self.dtype = dtype
 assert hostbuf is None or hostbuf.dtype == dtype, f"hostbuf dtype {hostbuf.dtype} != {dtype}"
-self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
+self._base_shape: Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
 self._buf = hostbuf._buf if hostbuf is not None else None
-self._backing : Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
+self._backing: Optional[np.ndarray] = hostbuf._backing if hostbuf is not None else backing
 assert self._backing is None or dtypes.from_np(self._backing) == dtype, f"backing dtype {dtypes.from_np(self._backing)} != {dtype}"
 if (self._backing is not None and self._backing.shape != (1,)) or force_create: self.raw()

 def __repr__(self): return f"{type(self).__name__}(shape={self.st}, hostbuf={type(self).__name__}(shape={self._base_shape}" + (f", backing=np.array({self._backing}, dtype=np.{self.dtype.np.__name__}), dtype={self.dtype}), dtype={self.dtype})" if self._backing is not None else f", force_create=True, dtype={self.dtype}), dtype={self.dtype})")

-raw_buffer_type : ClassVar[Type[RawBuffer]]
+raw_buffer_type: ClassVar[Type[RawBuffer]]
 @classmethod
 def create_raw_buffer(cls, shape:Tuple[int, ...], backing:Optional[np.ndarray], dtype:DType) -> RawBuffer:
 assert backing is None or prod(shape) == prod(backing.shape), "backing has the wrong shape"
@@ -191,10 +191,10 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
 if DEBUG >= 3: print(f"**** copy out {self.shape}")
 return self.contiguous().raw().toCPU().reshape(self.shape)

-codegen_type : ClassVar[Any]
-runtime_type : ClassVar[Type]
+codegen_type: ClassVar[Any]
+runtime_type: ClassVar[Type]

-method_cache : Final[Dict[str, ASTRunner]] = {}
+method_cache: Final[Dict[str, ASTRunner]] = {}
 @classmethod
 def exec_ast(cls, ast:LazyOp, output_buffer:Optional[CompiledBuffer]=None):
 k = cls.codegen_type(ast, output_buffer)
@@ -215,11 +215,11 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
 def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), hostbuf=self, dtype=self.dtype)

 class GlobalCounters:
-global_ops : ClassVar[int] = 0
-global_mem : ClassVar[int] = 0
-time_sum_s : ClassVar[float] = 0.0
-kernel_count : ClassVar[int] = 0
-mem_used : ClassVar[int] = 0 # NOTE: this is not reset
-cache : ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
+global_ops: ClassVar[int] = 0
+global_mem: ClassVar[int] = 0
+time_sum_s: ClassVar[float] = 0.0
+kernel_count: ClassVar[int] = 0
+mem_used: ClassVar[int] = 0 # NOTE: this is not reset
+cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
 @staticmethod
-def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
+def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
@@ -4,7 +4,7 @@ from tinygrad.ops import CompiledBuffer, RawBufferMapped
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage

 class RawMallocBuffer(RawBufferMapped):
-def __init__(self, size, dtype : DType):
+def __init__(self, size, dtype: DType):
 super().__init__(size, dtype)
 self._buf = ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16}[dtype] * size)()
 def _buffer(self): return memoryview(self._buf)
@@ -8,7 +8,7 @@ def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple
 assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
 return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

-base_fxn_for_op : Dict[Op, Callable] = {
+base_fxn_for_op: Dict[Op, Callable] = {
 UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
 BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
 ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
@@ -26,7 +26,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 return expand(ret.reshape([(1 if i not in a_axes and i not in b_axes else s) for i,s in enumerate(new_shape)]), new_shape)
 return mulacc

-numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
+numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
 UnaryOps.NOOP: np.ascontiguousarray, UnaryOps.EXP: np.exp, UnaryOps.LOG: np.log,
 BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
 MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
@@ -35,7 +35,7 @@ numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
 }}

 class CPUBuffer(InterpretedBuffer):
-fxn_for_op : ClassVar = numpy_fxn_for_op
+fxn_for_op: ClassVar = numpy_fxn_for_op
 to_tinygrad_dtype = staticmethod(dtypes.from_np)

 @staticmethod
@@ -14,7 +14,7 @@ FLOAT16 = getenv("FLOAT16", 0)
 class _CL:
 @functools.cached_property
 def cl_ctx(self) -> cl.Context:
-devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
+devices: List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
 if len(devices) == 0: devices = sum([x.get_devices(device_type=cl.device_type.CPU) for x in cl.get_platforms()], []) # settle for CPU
 if len(devices) > 1 or DEBUG >= 1: print(f"using {devices[getenv('CL_DEVICE', 0)]}")
 return cl.Context(devices=[devices[getenv("CL_DEVICE", 0)]])
@@ -32,7 +32,7 @@ class CLBuffer(RawBufferCopyInOut):
 def copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)

 class CLImage(RawBuffer): # pylint: disable=abstract-method
-IMAGE : Final = True
+IMAGE: Final = True
 def __init__(self, shape, dtype=dtypes.float16 if getenv("FLOAT16") else dtypes.float32): # pylint: disable=super-init-not-called
 fmt = cl.ImageFormat(cl.channel_order.RGBA, {dtypes.float16: cl.channel_type.HALF_FLOAT, dtypes.float32: cl.channel_type.FLOAT}[dtype])
 self.size, self.dtype, self._cl = shape, dtype, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(shape[1], shape[0]))
@@ -9,9 +9,9 @@ from tinygrad.codegen.llvm import LLVMCodegen
 import llvmlite.binding as llvm # type: ignore

 class LLVM:
-target_machine : ClassVar[llvm.targets.TargetMachine] = None
-engine : ClassVar[llvm.executionengine.ExecutionEngine] = None
-optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
+target_machine: ClassVar[llvm.targets.TargetMachine] = None
+engine: ClassVar[llvm.executionengine.ExecutionEngine] = None
+optimizer: ClassVar[llvm.passmanagers.ModulePassManager] = None

 def __init__(self):
 if LLVM.engine is not None: return
@@ -9,7 +9,7 @@ from tinygrad.ops import CompiledBuffer, RawBufferMapped
 METAL_XCODE = getenv("METAL_XCODE")

 class _METAL:
-mtl_buffers_in_flight : Final[List[Any]] = []
+mtl_buffers_in_flight: Final[List[Any]] = []
 @functools.cached_property
 def device(self) -> Any:
 return Metal.MTLCreateSystemDefaultDevice()
@@ -4,7 +4,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, FusedOps, Interpreted
 from tinygrad.helpers import getenv, dtypes
 from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc

-torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
+torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
 UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
 BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
 MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
@@ -14,7 +14,7 @@ torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{

 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 class TorchBuffer(InterpretedBuffer):
-fxn_for_op : ClassVar = torch_fxn_for_op
+fxn_for_op: ClassVar = torch_fxn_for_op
 to_tinygrad_dtype = staticmethod(lambda lbuf: {torch.float16: dtypes.float16, torch.float32: dtypes.float32}[lbuf.dtype])

 @staticmethod
@@ -27,7 +27,7 @@ class View:
 def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
 self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
 self.shape_strides = to_shape_strides(self.shape, self.strides)
-self.contiguous : bool = self.offset == 0 and is_contiguous(self.shape, self.strides)
+self.contiguous: bool = self.offset == 0 and is_contiguous(self.shape, self.strides)

 def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset})"

@@ -47,7 +47,7 @@ class View:
 class ZeroView:
 def __init__(self, old_shape:Tuple[int, ...], arg):
 self.old_shape, self.arg = old_shape, arg
-self.shape : Tuple[int, ...] = tuple([y-x for x,y in self.arg])
+self.shape: Tuple[int, ...] = tuple([y-x for x,y in self.arg])
 # fake properties
 self.strides, self.contiguous, self.offset = strides_for_shape(self.shape), False, 0

@@ -98,7 +98,7 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:

 class ShapeTracker:
 def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[ViewTypes]]=None):
-self.views : List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
+self.views: List[ViewTypes] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
 def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
 def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:])

@@ -145,11 +145,11 @@ class ShapeTracker:

 # *** under this line are the movement ops ***

-def __unsafe_resize(self, arg : Tuple[Tuple[int, int], ...]):
+def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...]):
 offset = sum([self.strides[i]*x for i,(x,_) in enumerate(arg)])
 self.views[-1] = View(tuple(y-x for x,y in arg), self.strides, self.offset+offset)

-def _pad(self, arg : Tuple[Tuple[int, int], ...]):
+def _pad(self, arg: Tuple[Tuple[int, int], ...]):
 assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
 if all(b==0 and e==0 for b,e in arg): return self # ZeroView is expensive if we don't need it
 zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg))
@@ -158,17 +158,17 @@ class ShapeTracker:
 # if we add a ZeroView, we add another (stock) view also for modding
 self.views += [zeroview, View(self.shape, strides_for_shape(self.shape))]

-def _shrink(self, arg : Tuple[Tuple[int, int], ...]):
+def _shrink(self, arg: Tuple[Tuple[int, int], ...]):
 assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
 self.__unsafe_resize(arg)

-def _expand(self, new_shape : Tuple[int, ...]):
+def _expand(self, new_shape: Tuple[int, ...]):
 assert all(isinstance(x, int) for x in new_shape), f"non ints for expand in {new_shape}"
 assert all(x == y or x == 1 for x,y in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
-strides : Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
+strides: Tuple[int, ...] = tuple(s if x == y else 0 for s,(x,y) in zip(self.strides, zip(self.shape, new_shape)))
 self.views[-1] = View(new_shape, strides, self.offset)

-def _reshape(self, new_shape : Tuple[int, ...]):
+def _reshape(self, new_shape: Tuple[int, ...]):
 if self.shape == new_shape: return self
 assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
 assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
@@ -188,13 +188,13 @@ class ShapeTracker:
 if (merged_view := merge_views(cast(View, self.views[-1]), view)) is not None: self.views[-1] = merged_view
 else: self.views.append(view)

-def _permute(self, axis : Tuple[int, ...]):
+def _permute(self, axis: Tuple[int, ...]):
 assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
 assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
 self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset)

 # except for the negative case, you can build this from the others. invertible in the negative case
-def _stride(self, mul : Tuple[int, ...]):
+def _stride(self, mul: Tuple[int, ...]):
 assert all(isinstance(x, int) for x in mul)
 strides = tuple(z*m for z,m in zip(self.strides, mul))
 new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
@@ -208,16 +208,16 @@ class ShapeTracker:
 dispatch[op](self, arg)
 return self

-dispatch : Dict[MovementOps, Callable] = {MovementOps.RESHAPE: ShapeTracker._reshape, MovementOps.EXPAND: ShapeTracker._expand, MovementOps.PAD: ShapeTracker._pad,
+dispatch: Dict[MovementOps, Callable] = {MovementOps.RESHAPE: ShapeTracker._reshape, MovementOps.EXPAND: ShapeTracker._expand, MovementOps.PAD: ShapeTracker._pad,
 MovementOps.SHRINK: ShapeTracker._shrink, MovementOps.PERMUTE: ShapeTracker._permute, MovementOps.STRIDE: ShapeTracker._stride}

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]):
 # Pre-allocate all groups.
-axis_groups : List[List[int]] = [[] for _ in range(len(new_shape))]
+axis_groups: List[List[int]] = [[] for _ in range(len(new_shape))]
 # Index for new_shape and axis_groups.
-i : int = 0
-old_shape_i : int = 0
+i: int = 0
+old_shape_i: int = 0
 while old_shape_i < len(old_shape):
 # 1s exist in new_shape only will lead to empty axes group creations.
 if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
@@ -177,7 +177,7 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
 elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
 return create_node(ret)

-render_python : Dict[Type, Callable] = {
+render_python: Dict[Type, Callable] = {
 Variable: lambda self,ops,ctx: f"{self.expr}<{self.min},{self.max}>" if ctx == "DEBUG" else f"{self.expr}",
 NumNode: lambda self,ops,ctx: f"{self.b}",
 MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
@@ -30,9 +30,9 @@ import tinygrad.mlops as mlops

 class Tensor:
 __deletable__ = ('_ctx',)
-training : ClassVar[bool] = False
-no_grad : ClassVar[bool] = False
-default_type : ClassVar[DType] = dtypes.float32
+training: ClassVar[bool] = False
+no_grad: ClassVar[bool] = False
+default_type: ClassVar[DType] = dtypes.float32

 def __init__(self, data, device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
 if isinstance(data, list):
@@ -51,14 +51,14 @@ class Tensor:
 raise RuntimeError(f"can't create Tensor from {data}")

 # tensors have gradients, buffers do not
-self.grad : Optional[Tensor] = None
+self.grad: Optional[Tensor] = None

 # NOTE: this can be in three states. False and None: no gradient, True: gradient
 # None (the default) will be updated to True if it's put in an optimizer
-self.requires_grad : Optional[bool] = requires_grad
+self.requires_grad: Optional[bool] = requires_grad

 # internal variables used for autograd graph construction
-self._ctx : Optional[Function] = None
+self._ctx: Optional[Function] = None

 def __repr__(self):
 return f"<Tensor {self.lazydata if self.lazydata.realized is None else self.lazydata.realized!r} with grad {(self.grad.lazydata if self.grad else None)!r}>"
@@ -128,7 +128,7 @@ class Tensor:
 # ***** (numpy) rng helper functions *****
 # TODO: move randomness generation out of numpy

-_rng : ClassVar[np.random.Generator] = np.random.default_rng()
+_rng: ClassVar[np.random.Generator] = np.random.default_rng()
 @staticmethod
 def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)

@@ -262,7 +262,7 @@ class Tensor:
 # ***** reduce ops *****

 def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
-axis_ : List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
+axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
 axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
 shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
 ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
@@ -445,7 +445,7 @@ class Tensor:
 def dropout(self, p=0.5) -> Tensor:
 if not Tensor.training: return self
 # TODO: why is this going through numpy?
-_mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=np.float32)
+_mask: np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-p, size=self.shape), dtype=np.float32)
 return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))

 # ***** cast ops *****