From 45ce4de6f384f279cb5fef86f499b55f7608fcff Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 8 Feb 2023 12:48:21 -0600
Subject: [PATCH] improve typing

---
 tinygrad/llops/ops_cpu.py   |  4 ++--
 tinygrad/llops/ops_llvm.py  | 14 +++++++-------
 tinygrad/llops/ops_torch.py |  4 ++--
 tinygrad/ops.py             |  1 +
 4 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 16b1b911b1..11265eeb42 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -1,5 +1,5 @@
 import numpy as np
-from typing import Final
+from typing import ClassVar
 from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
 
 specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
@@ -10,7 +10,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
 })
 
 class CPUBuffer(GenericBufExecAST):
-  fxn_for_op : Final = specialized_fxn_for_op
+  fxn_for_op : ClassVar = specialized_fxn_for_op
 
   def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
   @staticmethod
diff --git a/tinygrad/llops/ops_llvm.py b/tinygrad/llops/ops_llvm.py
index d81b27943e..f322bbb696 100644
--- a/tinygrad/llops/ops_llvm.py
+++ b/tinygrad/llops/ops_llvm.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import hashlib
 import math
 import time
-from typing import Tuple, Union, Dict, Any, List
+from typing import Tuple, Union, Dict, Any, List, ClassVar
 from tinygrad.helpers import prod, getenv
 from tinygrad.shape import ShapeTracker, ZeroView
 from tinygrad.ops import LazyOp
@@ -68,9 +68,9 @@ def idx_deref(builder, buf, ptr, idx):
   return builder.load(builder.gep(ptr, [idx], inbounds=True))
 
 class LLVM:
-  target_machine = None
-  engine = None
-  optimizer = None
+  target_machine : ClassVar[llvm.targets.TargetMachine] = None
+  engine : ClassVar[llvm.executionengine.ExecutionEngine] = None
+  optimizer : ClassVar[llvm.passmanagers.ModulePassManager] = None
 
   def __init__(self):
     if LLVM.engine is not None:
@@ -104,7 +104,7 @@ class LLVM:
     backing_mod.triple = llvm.get_process_triple()
     LLVM.engine = llvm.create_mcjit_compiler(backing_mod, LLVM.target_machine)
 
-  def exec(self, module, bufs, op_estimate=0, mem_estimate=0):
+  def exec(self, module:ir.Module, bufs, op_estimate=0, mem_estimate=0):
     module.triple = llvm.get_process_triple()
     module.data_layout = self.engine.target_data
     llvm_ir = str(module)
@@ -146,7 +146,7 @@ class LLVM:
 
 # TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
 class LLVMBuffer(ExplicitExecAST):
-  op_lookup = {
+  op_lookup : ClassVar = {
     UnaryOps.NOOP: lambda builder,x: x,
     UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
@@ -161,7 +161,7 @@ class LLVMBuffer(ExplicitExecAST):
     BinaryOps.POW: lambda builder,x,y: builder.call(builder._block.module.declare_intrinsic('llvm.pow', [ir.FloatType()]), [x,y], fastmath=('fast',)),
     BinaryOps.CMPEQ: lambda builder,x,y: builder.uitofp(builder.fcmp_ordered("==", x, y, flags=('fast',)), ir.FloatType())
   }
-  start_for_op = {
+  start_for_op : ClassVar = {
     ReduceOps.SUM: ir.Constant(ir.FloatType(), 0),
     ReduceOps.MAX: ir.Constant(ir.FloatType(), -math.inf)
   }
diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 7b7b4d5aa0..de185f961b 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Final
+from typing import ClassVar
 from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
 from tinygrad.helpers import getenv
 
@@ -11,7 +11,7 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 
 class TorchBuffer(GenericBufExecAST):
-  fxn_for_op : Final = specialized_fxn_for_op
+  fxn_for_op : ClassVar = specialized_fxn_for_op
 
   def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
   @staticmethod
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 1213ec0a8c..e5348f0865 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -67,6 +67,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
 
 # used in CPUBuffer and TorchBuffer
 class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
+  fxn_for_op : ClassVar
   # TODO: use generic types here to remove __init__ in specialized classes
   def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
   def contiguous(self): return self.unary_op(UnaryOps.NOOP)
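
Why ClassVar rather than Final: under mypy's semantics, Final on a class
attribute means "never reassigned and never overridden in a subclass",
while ClassVar marks a class-level attribute that subclasses may replace
with their own value. A minimal sketch of the distinction; the class
names below are illustrative stand-ins, not the patch's classes:

  from typing import Any, Callable, ClassVar, Dict, Final

  class BufferBase:
    # ClassVar: lives on the class, shared by all instances; subclasses
    # may swap in their own table, which is what the new
    # "fxn_for_op : ClassVar" declaration on GenericBufExecAST allows.
    fxn_for_op: ClassVar[Dict[Any, Callable[..., Any]]] = {}

  class FastBuffer(BufferBase):
    fxn_for_op: ClassVar[Dict[Any, Callable[..., Any]]] = {"neg": lambda x: -x}  # ok

  class FinalBuffer:
    # Final: mypy rejects the subclass below, yet overriding the table is
    # exactly what CPUBuffer and TorchBuffer need to do.
    fxn_for_op: Final[Dict[Any, Callable[..., Any]]] = {}

  class BrokenBuffer(FinalBuffer):
    fxn_for_op = {"neg": lambda x: -x}  # mypy error: cannot override final attribute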
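
The same marker also fits the lazily built singletons on the LLVM class,
with one caveat: an attribute initialized to None, as these are, strictly
wants Optional[...] inside the ClassVar, or mypy's strict optional
checking rejects the None default. A rough sketch of the pattern; JIT and
Engine are stand-in names, not the llvmlite classes the patch annotates:

  from typing import ClassVar, Optional

  class Engine: ...  # stand-in for llvm.executionengine.ExecutionEngine

  class JIT:
    # built once, on the first construction, then shared by every instance
    engine: ClassVar[Optional[Engine]] = None

    def __init__(self):
      if JIT.engine is not None:
        return
      JIT.engine = Engine()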