mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix for undefined dtypes in jit during loading defaults (#2114)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -2172,12 +2172,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
Y, stride_yk, stride_yn,
|
||||
W, stride_wn, stride_wl,
|
||||
Z, stride_zm, stride_zn,
|
||||
out_dtype: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
|
||||
ALLOW_TF32: tl.constexpr,
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
COL_A: tl.constexpr, COL_B: tl.constexpr):
|
||||
COL_A: tl.constexpr, COL_B: tl.constexpr,
|
||||
out_dtype: tl.constexpr = tl.float32):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
off_l = tl.arange(0, BLOCK_N)
|
||||
@@ -2251,7 +2251,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
out_dtype,
|
||||
COL_A=col_a, COL_B=col_b,
|
||||
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
|
||||
ADD_MATRIX=epilogue == 'add-matrix',
|
||||
@@ -2260,7 +2259,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
DO_SOFTMAX=epilogue == 'softmax',
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps, num_ctas=num_ctas)
|
||||
num_warps=num_warps, num_ctas=num_ctas,
|
||||
out_dtype=out_dtype)
|
||||
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
|
||||
ptx = pgm.asm["ptx"]
|
||||
start = ptx.find("shfl.sync")
|
||||
|
||||
@@ -116,5 +116,3 @@ def test_line_info(func: str):
|
||||
assert (check_file_lines(file_lines, "standard.py", 33))
|
||||
assert (check_file_lines(file_lines, "standard.py", 34))
|
||||
assert (check_file_lines(file_lines, "standard.py", 36))
|
||||
# core.py is changed frequently, so we only check if it exists
|
||||
assert (check_file_lines(file_lines, "core.py", -1))
|
||||
|
||||
@@ -4,11 +4,21 @@
|
||||
from . import math
|
||||
from . import extra
|
||||
from .standard import (
|
||||
argmax,
|
||||
argmin,
|
||||
cdiv,
|
||||
cumprod,
|
||||
cumsum,
|
||||
max,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
sigmoid,
|
||||
softmax,
|
||||
sum,
|
||||
ravel,
|
||||
swizzle2d,
|
||||
xor_sum,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
@@ -17,8 +27,6 @@ from .core import (
|
||||
abs,
|
||||
advance,
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
associative_scan,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
@@ -35,8 +43,6 @@ from .core import (
|
||||
cat,
|
||||
constexpr,
|
||||
cos,
|
||||
cumprod,
|
||||
cumsum,
|
||||
debug_barrier,
|
||||
device_assert,
|
||||
device_print,
|
||||
@@ -63,12 +69,8 @@ from .core import (
|
||||
load,
|
||||
log,
|
||||
make_block_ptr,
|
||||
max,
|
||||
max_constancy,
|
||||
max_contiguous,
|
||||
maximum,
|
||||
min,
|
||||
minimum,
|
||||
multiple_of,
|
||||
num_programs,
|
||||
pi32_t,
|
||||
@@ -81,7 +83,6 @@ from .core import (
|
||||
static_assert,
|
||||
static_print,
|
||||
store,
|
||||
sum,
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
@@ -94,7 +95,6 @@ from .core import (
|
||||
view,
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
|
||||
@@ -6,8 +6,7 @@ from functools import wraps
|
||||
from typing import Callable, List, Sequence, TypeVar
|
||||
|
||||
from .._C.libtriton.triton import ir
|
||||
from ..runtime.jit import jit
|
||||
from . import math, semantic
|
||||
from . import semantic
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@@ -205,6 +204,10 @@ class dtype:
|
||||
def is_bool(self):
|
||||
return self.is_int1()
|
||||
|
||||
@staticmethod
|
||||
def is_dtype(type_str):
|
||||
return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
|
||||
|
||||
@staticmethod
|
||||
def is_void():
|
||||
raise RuntimeError("Not implemented")
|
||||
@@ -1380,170 +1383,6 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None
|
||||
return rvalue, rindices
|
||||
|
||||
|
||||
@jit
|
||||
def minimum(x, y):
|
||||
"""
|
||||
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return math.min(x, y)
|
||||
|
||||
|
||||
@jit
|
||||
def maximum(x, y):
|
||||
"""
|
||||
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
|
||||
if tie_break_left:
|
||||
tie = value1 == value2 and index1 < index2
|
||||
else:
|
||||
tie = False
|
||||
gt = value1 > value2 or tie
|
||||
v_ret = where(gt, value1, value2)
|
||||
i_ret = where(gt, index1, index2)
|
||||
return v_ret, i_ret
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, True)
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = _promote_reduction_input(input)
|
||||
if return_indices:
|
||||
if return_indices_tie_break_left:
|
||||
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_left)
|
||||
else:
|
||||
return _reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
|
||||
else:
|
||||
if constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if constexpr(input.dtype.is_floating()):
|
||||
input = input.to(float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, maximum)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
|
||||
def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
|
||||
if tie_break_left:
|
||||
tie = value1 == value2 and index1 < index2
|
||||
else:
|
||||
tie = False
|
||||
lt = value1 < value2 or tie
|
||||
value_ret = where(lt, value1, value2)
|
||||
index_ret = where(lt, index1, index2)
|
||||
return value_ret, index_ret
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, True)
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = _promote_reduction_input(input)
|
||||
if return_indices:
|
||||
if return_indices_tie_break_left:
|
||||
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_left)
|
||||
else:
|
||||
return _reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
|
||||
else:
|
||||
if constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if constexpr(input.dtype.is_floating()):
|
||||
input = input.to(float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(int32)
|
||||
return reduce(input, axis, minimum)
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
@jit
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@jit
|
||||
@_add_reduction_docstr("sum")
|
||||
def sum(input, axis=None):
|
||||
input = _promote_reduction_input(input)
|
||||
return reduce(input, axis, _sum_combine)
|
||||
|
||||
|
||||
@jit
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
|
||||
# xor sum
|
||||
|
||||
@builtin
|
||||
@_add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
scalar_ty = input.type.scalar
|
||||
if not scalar_ty.is_int():
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = _promote_reduction_input(input, _builder=_builder)
|
||||
return reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Scans
|
||||
# -----------------------
|
||||
@@ -1594,31 +1433,6 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
axis = _constexpr_to_value(axis)
|
||||
return semantic.associative_scan(input, axis, make_combine_region, _builder)
|
||||
|
||||
# cumsum
|
||||
|
||||
|
||||
@jit
|
||||
@_add_scan_docstr("cumsum")
|
||||
def cumsum(input, axis=0):
|
||||
# todo rename this to a generic function name
|
||||
input = _promote_reduction_input(input)
|
||||
return associative_scan(input, axis, _sum_combine)
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
@jit
|
||||
def _prod_combine(a, b):
|
||||
return a * b
|
||||
|
||||
|
||||
@jit
|
||||
@_add_scan_docstr("cumprod")
|
||||
def cumprod(input, axis=0):
|
||||
# todo rename this to a generic function name
|
||||
input = _promote_reduction_input(input)
|
||||
return associative_scan(input, axis, _prod_combine)
|
||||
|
||||
# -----------------------
|
||||
# Compiler Hint Ops
|
||||
# -----------------------
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ..runtime.jit import jit
|
||||
from . import core as tl
|
||||
from . import standard
|
||||
|
||||
PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
|
||||
PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
|
||||
@@ -141,7 +142,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
||||
@jit
|
||||
def pair_uniform_to_normal(u1, u2):
|
||||
"""Box-Muller transform"""
|
||||
u1 = tl.maximum(1.0e-7, u1)
|
||||
u1 = standard.maximum(1.0e-7, u1)
|
||||
th = 6.283185307179586 * u2
|
||||
r = tl.sqrt(-2.0 * tl.log(u1))
|
||||
return r * tl.cos(th), r * tl.sin(th)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..runtime.jit import jit
|
||||
from . import core
|
||||
from . import core, math
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
@@ -30,9 +30,9 @@ def sigmoid(x):
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - core.max(x, 0)
|
||||
z = x - max(x, 0)
|
||||
num = core.exp(z)
|
||||
den = core.sum(num, 0)
|
||||
den = sum(num, 0)
|
||||
return core.fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
# row-index of the first element of this group
|
||||
off_i = group_id * size_g
|
||||
# last group may have fewer rows
|
||||
size_g = core.minimum(size_i - off_i, size_g)
|
||||
size_g = minimum(size_i - off_i, size_g)
|
||||
# new row and column indices
|
||||
new_i = off_i + (ij % size_g)
|
||||
new_j = (ij % size_gj) // size_g
|
||||
@@ -96,3 +96,192 @@ def zeros(shape, dtype):
|
||||
@jit
|
||||
def zeros_like(input):
|
||||
return zeros(input.shape, input.dtype)
|
||||
|
||||
|
||||
@jit
|
||||
def minimum(x, y):
|
||||
"""
|
||||
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return math.min(x, y)
|
||||
|
||||
|
||||
@jit
|
||||
def maximum(x, y):
|
||||
"""
|
||||
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return math.max(x, y)
|
||||
|
||||
# max and argmax
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
|
||||
if tie_break_left:
|
||||
tie = value1 == value2 and index1 < index2
|
||||
else:
|
||||
tie = False
|
||||
gt = value1 > value2 or tie
|
||||
v_ret = core.where(gt, value1, value2)
|
||||
i_ret = core.where(gt, index1, index2)
|
||||
return v_ret, i_ret
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, True)
|
||||
|
||||
|
||||
@jit
|
||||
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmax_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum",
|
||||
return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
if return_indices:
|
||||
if return_indices_tie_break_left:
|
||||
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left)
|
||||
else:
|
||||
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
|
||||
else:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if core.constexpr(input.dtype.is_floating()):
|
||||
input = input.to(core.float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(core.int32)
|
||||
return core.reduce(input, axis, maximum)
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
|
||||
def argmax(input, axis, tie_break_left=True):
|
||||
(_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
# min and argmin
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
|
||||
if tie_break_left:
|
||||
tie = value1 == value2 and index1 < index2
|
||||
else:
|
||||
tie = False
|
||||
lt = value1 < value2 or tie
|
||||
value_ret = core.where(lt, value1, value2)
|
||||
index_ret = core.where(lt, index1, index2)
|
||||
return value_ret, index_ret
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, True)
|
||||
|
||||
|
||||
@jit
|
||||
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
|
||||
return _argmin_combine(value1, index1, value2, index2, False)
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum",
|
||||
return_indices_arg="return_indices",
|
||||
tie_break_arg="return_indices_tie_break_left")
|
||||
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True):
|
||||
input = core._promote_reduction_input(input)
|
||||
if return_indices:
|
||||
if return_indices_tie_break_left:
|
||||
return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left)
|
||||
else:
|
||||
return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast)
|
||||
else:
|
||||
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
|
||||
if core.constexpr(input.dtype.is_floating()):
|
||||
input = input.to(core.float32)
|
||||
else:
|
||||
assert input.dtype.is_integer_type()
|
||||
input = input.to(core.int32)
|
||||
return core.reduce(input, axis, minimum)
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("minimum index",
|
||||
tie_break_arg="tie_break_left")
|
||||
def argmin(input, axis, tie_break_left=True):
|
||||
_, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left)
|
||||
return ret
|
||||
|
||||
|
||||
@jit
|
||||
def _sum_combine(a, b):
|
||||
return a + b
|
||||
|
||||
# sum
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_reduction_docstr("sum")
|
||||
def sum(input, axis=None):
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.reduce(input, axis, _sum_combine)
|
||||
|
||||
|
||||
@jit
|
||||
def _xor_combine(a, b):
|
||||
return a ^ b
|
||||
|
||||
# xor sum
|
||||
|
||||
|
||||
@core.builtin
|
||||
@core._add_reduction_docstr("xor sum")
|
||||
def xor_sum(input, axis=None, _builder=None, _generator=None):
|
||||
scalar_ty = input.type.scalar
|
||||
if not scalar_ty.is_int():
|
||||
raise ValueError("xor_sum only supported for integers")
|
||||
|
||||
input = core._promote_reduction_input(input, _builder=_builder)
|
||||
return core.reduce(input, axis, _xor_combine,
|
||||
_builder=_builder, _generator=_generator)
|
||||
|
||||
# cumsum
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_scan_docstr("cumsum")
|
||||
def cumsum(input, axis=0):
|
||||
# todo rename this to a generic function name
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _sum_combine)
|
||||
|
||||
# cumprod
|
||||
|
||||
|
||||
@jit
|
||||
def _prod_combine(a, b):
|
||||
return a * b
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_scan_docstr("cumprod")
|
||||
def cumprod(input, axis=0):
|
||||
# todo rename this to a generic function name
|
||||
input = core._promote_reduction_input(input)
|
||||
return core.associative_scan(input, axis, _prod_combine)
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
|
||||
|
||||
from .._C.libtriton.triton import TMAInfos
|
||||
from ..common.backend import get_backend, path_to_ptxas
|
||||
from ..language.core import dtype
|
||||
|
||||
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TRITON_VERSION = "2.1.0"
|
||||
@@ -358,10 +359,11 @@ class JITFunction(KernelInterface[T]):
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults))
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
|
||||
Reference in New Issue
Block a user