viz fixups + scheduler option [pr] (#8557)

This commit is contained in:
George Hotz
2025-01-10 15:09:31 -08:00
committed by GitHub
parent f457cb64d6
commit 70fa65cd95
7 changed files with 8 additions and 11 deletions

View File

@@ -419,7 +419,7 @@ def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype)
return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous()
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)

View File

@@ -456,7 +456,7 @@ def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> Non
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
# early realize before expand
if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src)
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
# otherwise safety check pads
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,4 @@ fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.cs
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js.map"
fetch "unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css"

View File

@@ -10,7 +10,6 @@
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"></script>
<link rel="stylesheet" href="assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
<style>
* {
@@ -276,7 +275,9 @@
href: "vscode://file"+parts.join("/"), style: "font-family: monospace; margin: 4px 0;" });
const highlightedCodeBlock = (code, lang, wrap) => {
const pre = Object.assign(document.createElement("pre"), {className: wrap ? "wrap" : ""});
const codeEl= Object.assign(document.createElement("code"), {className: `language-${lang} code-block`, textContent: DOMPurify.sanitize(code)});
const codeEl = Object.assign(document.createElement("code"), {
// NOTE: since code is in textContent, we don't need DOMPurify
className: `language-${lang} code-block`, textContent: code});
pre.appendChild(codeEl);
hljs.highlightElement(codeEl);
return pre;
@@ -363,7 +364,7 @@
let code = ret.uops[currentRewrite];
let lang = "python"
if (ret.kernel_code != null) {
code = ret.kernel_code.replaceAll("<", "&lt;").replaceAll(">", "&gt;");
code = ret.kernel_code;
lang = "cpp";
}
const codeBlock = highlightedCodeBlock(code, lang, false);

View File

@@ -71,7 +71,8 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
if u in excluded: continue
argst = str(u.arg)
if u.op is Ops.VIEW:
argst = ("\n".join([f"{v.shape} / {v.strides}"+(f" / {v.offset}" if v.offset is not None else "") for v in unwrap(u.st).views]))
argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views]))
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x in excluded: