mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
viz fixups + scheduler option [pr] (#8557)
This commit is contained in:
@@ -419,7 +419,7 @@ def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
|
||||
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
|
||||
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
|
||||
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
|
||||
return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype)
|
||||
return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous()
|
||||
|
||||
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
|
||||
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
|
||||
|
||||
@@ -456,7 +456,7 @@ def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> Non
|
||||
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
|
||||
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
|
||||
# early realize before expand
|
||||
if resolve(prod(src.shape) < prod(st.shape)): return realize(ctx, b, src)
|
||||
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -11,6 +11,4 @@ fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.cs
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js.map"
|
||||
fetch "unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css"
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js"></script>
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/dompurify/1.0.3/purify.min.js"></script>
|
||||
<link rel="stylesheet" href="assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css" />
|
||||
<style>
|
||||
* {
|
||||
@@ -276,7 +275,9 @@
|
||||
href: "vscode://file"+parts.join("/"), style: "font-family: monospace; margin: 4px 0;" });
|
||||
const highlightedCodeBlock = (code, lang, wrap) => {
|
||||
const pre = Object.assign(document.createElement("pre"), {className: wrap ? "wrap" : ""});
|
||||
const codeEl= Object.assign(document.createElement("code"), {className: `language-${lang} code-block`, textContent: DOMPurify.sanitize(code)});
|
||||
const codeEl = Object.assign(document.createElement("code"), {
|
||||
// NOTE: since code is in textContent, we don't need DOMPurify
|
||||
className: `language-${lang} code-block`, textContent: code});
|
||||
pre.appendChild(codeEl);
|
||||
hljs.highlightElement(codeEl);
|
||||
return pre;
|
||||
@@ -363,7 +364,7 @@
|
||||
let code = ret.uops[currentRewrite];
|
||||
let lang = "python"
|
||||
if (ret.kernel_code != null) {
|
||||
code = ret.kernel_code.replaceAll("<", "<").replaceAll(">", ">");
|
||||
code = ret.kernel_code;
|
||||
lang = "cpp";
|
||||
}
|
||||
const codeBlock = highlightedCodeBlock(code, lang, false);
|
||||
|
||||
@@ -71,7 +71,8 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
|
||||
if u in excluded: continue
|
||||
argst = str(u.arg)
|
||||
if u.op is Ops.VIEW:
|
||||
argst = ("\n".join([f"{v.shape} / {v.strides}"+(f" / {v.offset}" if v.offset is not None else "") for v in unwrap(u.st).views]))
|
||||
argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
|
||||
("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views]))
|
||||
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
for idx,x in enumerate(u.src):
|
||||
if x in excluded:
|
||||
|
||||
Reference in New Issue
Block a user