Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
stable diffusion profiling (#13441)
* stable diffusion profiling

Signed-off-by: George Hotz <geohot@gmail.com>

* profile_marker
* profile per step
* fix slow Context
* profile that

---------

Signed-off-by: George Hotz <geohot@gmail.com>
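For context, a minimal usage sketch (not part of this commit) of the new `profile_marker` helper: it drops a named point into tinygrad's profile timeline so each phase of a script shows up as a labeled region. Enabling the profiler (e.g. via `PROFILE=1`) is an assumption here; the capture flow is not shown in this diff.

```python
# Hedged usage sketch: label the phases of a script with the new profile_marker.
# Assumes the profiler is enabled (e.g. PROFILE=1 in the environment).
from tinygrad import Tensor
from tinygrad.helpers import profile_marker

profile_marker("build tensors")
a, b = Tensor.rand(1024, 1024), Tensor.rand(1024, 1024)

profile_marker("matmul")
(a @ b).realize()
```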
```diff
@@ -9,7 +9,7 @@ from typing import Dict, Any
 from PIL import Image
 import numpy as np
 from tinygrad import Device, GlobalCounters, dtypes, Tensor, TinyJit
-from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten
+from tinygrad.helpers import Timing, Context, getenv, fetch, colored, tqdm, flatten, profile_marker
 from tinygrad.nn import Conv2d, GroupNorm
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
 from extra.models.clip import Closed, Tokenizer, FrozenOpenClipEmbedder
```
```diff
@@ -266,9 +266,10 @@ if __name__ == "__main__":
   parser.add_argument('--fakeweights', action='store_true', help="Skip loading checkpoints and use fake weights")
   args = parser.parse_args()
 
+  profile_marker("create model")
   model = StableDiffusion()
 
   # load in weights
+  profile_marker("load in weights")
   with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
     if not args.fakeweights:
       model_bin = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
```
```diff
@@ -281,12 +282,13 @@ if __name__ == "__main__":
 
   Tensor.realize(*get_state_dict(model).values())
 
   # run through CLIP to get context
+  profile_marker("run clip (conditional)")
   tokenizer = Tokenizer.ClipTokenizer()
   prompt = Tensor([tokenizer.encode(args.prompt)])
   context = model.cond_stage_model.transformer.text_model(prompt).realize()
   print("got CLIP context", context.shape)
 
+  profile_marker("run clip (unconditional)")
   prompt = Tensor([tokenizer.encode("")])
   unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
   print("got unconditional CLIP context", unconditional_context.shape)
```
```diff
@@ -310,6 +312,7 @@ if __name__ == "__main__":
   step_times = []
   with Context(BEAM=getenv("LATEBEAM")):
     for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+      profile_marker(f"step {len(timesteps)-index-1}")
       GlobalCounters.reset()
       st = time.perf_counter_ns()
       t.set_description("%3d %3d" % (index, timestep))
```
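The marker name uses `len(timesteps)-index-1` because the loop walks the timesteps in reverse, so this maps the reversed `enumerate` index back to a forward step number starting at 0. A standalone illustration:

```python
# Illustration only: reversed iteration with forward step numbering.
timesteps = [100, 200, 300]
for index, timestep in list(enumerate(timesteps))[::-1]:
  step = len(timesteps) - index - 1  # 0 for timestep 300 ... 2 for timestep 100
  print(f"step {step}: timestep {timestep}")
```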
```diff
@@ -319,24 +322,26 @@ if __name__ == "__main__":
       latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
       if args.timing: Device[Device.DEFAULT].synchronize()
       step_times.append((time.perf_counter_ns() - st)*1e-6)
   # done with diffusion model
   del run
   del model.model
 
   if (assert_time:=getenv("ASSERT_MIN_STEP_TIME")):
     min_time = min(step_times)
     assert min_time < assert_time, f"Speed regression, expected min step time of < {assert_time} ms but took: {min_time} ms"
-  # upsample latent space to image with autoencoder
-  x = model.decode(latent)
+  profile_marker("run decoder") # upsample latent space to image with autoencoder
+  x = model.decode(latent).realize()
   print(x.shape)
 
   # save image
+  profile_marker("save image")
   im = Image.fromarray(x.numpy())
   print(f"saving {args.out}")
   im.save(args.out)
   # Open image.
   if not args.noshow: im.show()
 
   # validation!
   if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
+    profile_marker("validate")
     ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
     distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
     assert distance < 3e-3, colored(f"validation failed with {distance=}", "red") # higher distance with WINO
```
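The per-step timing above doubles as a regression check: each step is timed on the host in milliseconds, and `ASSERT_MIN_STEP_TIME` gates the assert (tinygrad's `getenv` returns 0 when the variable is unset, so the check is skipped by default). A minimal sketch of the same pattern, detached from the model:

```python
# Minimal sketch of the step-timing / regression-assert pattern.
import time
from tinygrad.helpers import getenv

step_times: list[float] = []
for _ in range(3):
  st = time.perf_counter_ns()
  time.sleep(0.005)  # stand-in for one denoising step
  step_times.append((time.perf_counter_ns() - st) * 1e-6)  # ns -> ms

if (assert_time := getenv("ASSERT_MIN_STEP_TIME")):  # 0 (falsy) when unset
  min_time = min(step_times)
  assert min_time < assert_time, f"Speed regression, min step {min_time:.2f} ms >= {assert_time} ms"
```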
```diff
@@ -147,8 +147,10 @@ def temp(x:str, append_user:bool=False) -> str:
 class Context(contextlib.ContextDecorator):
   def __init__(self, **kwargs): self.kwargs = kwargs
   def __enter__(self):
-    self.old_context:dict[str, int] = {k:v.value for k,v in ContextVar._cache.items()}
-    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v
+    self.old_context:dict[str, int] = {}
+    for k,v in self.kwargs.items():
+      self.old_context[k] = ContextVar._cache[k].value
+      ContextVar._cache[k].value = v
   def __exit__(self, *args):
     for k,v in self.old_context.items(): ContextVar._cache[k].value = v
```
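This is the "fix slow Context" item: the old `__enter__` snapshotted every registered `ContextVar` on each `with Context(...)`, which scales with the size of the cache, while the new version saves and restores only the variables actually overridden. A self-contained sketch of the pattern (simplified stand-in classes, not tinygrad's actual ones):

```python
# Self-contained sketch of the save-only-what-you-touch Context pattern.
import contextlib

class ContextVar:
  _cache: dict = {}
  def __init__(self, key, default): self.value, ContextVar._cache[key] = default, self

class Context(contextlib.ContextDecorator):
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    self.old_context = {}
    for k, v in self.kwargs.items():
      self.old_context[k] = ContextVar._cache[k].value  # save only overridden keys
      ContextVar._cache[k].value = v
  def __exit__(self, *args):
    for k, v in self.old_context.items(): ContextVar._cache[k].value = v

BEAM = ContextVar("BEAM", 0)
with Context(BEAM=2): assert BEAM.value == 2
assert BEAM.value == 0  # restored on exit
```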
```diff
@@ -279,7 +281,7 @@ class ProfilePointEvent(ProfileEvent):
 
 cpu_events:list[ProfileEvent] = []
 @contextlib.contextmanager
-def cpu_profile(name:str|TracingKey, device="CPU", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
+def cpu_profile(name:str|TracingKey, device="TINY", is_copy=False, display=True) -> Generator[ProfileRangeEvent, None, None]:
   res = ProfileRangeEvent(device, name, perf_counter_us(), is_copy=is_copy)
   try: yield res
   finally:
```
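`cpu_profile` complements the point markers: it is a context manager that records a named `ProfileRangeEvent` around host-side work, now filed under the "TINY" device timeline by default. A hedged usage sketch, assuming it is importable from `tinygrad.helpers` like the other helpers in this diff:

```python
# Usage sketch: wrap host-side work in a named profile range.
import time
from tinygrad.helpers import cpu_profile

with cpu_profile("tokenize prompt"):  # records a ProfileRangeEvent, default device "TINY"
  time.sleep(0.01)  # stand-in for real CPU work
```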
```diff
@@ -382,7 +384,11 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
 
 # *** Exec helpers
 
-def system(cmd, **kwargs): return subprocess.check_output(cmd.split(), **kwargs).decode().strip()
+def system(cmd:str, **kwargs) -> str:
+  st = time.perf_counter()
+  ret = subprocess.check_output(cmd.split(), **kwargs).decode().strip()
+  if DEBUG >= 1: print(f"system: '{cmd}' returned {len(ret)} bytes in {(time.perf_counter() - st)*1e3:.2f} ms")
+  return ret
 
 def cpu_objdump(lib, objdump_tool='objdump'):
   with tempfile.NamedTemporaryFile(delete=True) as f:
```
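The `system` helper keeps its one-line call sites but gains type annotations and, at `DEBUG >= 1`, a report of how long the subprocess took. Note it splits the command on whitespace rather than invoking a shell, so quoted arguments are not supported:

```python
# Usage sketch for the reworked system() helper.
from tinygrad.helpers import system

print(system("uname -s"))  # decoded, stripped stdout, e.g. "Linux"; DEBUG=1 also prints timing
```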
```diff
@@ -260,8 +260,8 @@ class Tensor(OpMixin):
     _apply_map_to_tensors(remove_assign_map, name="Remove After")
 
     # create the schedule
-    schedule, var_vals = create_schedule_with_vars(sink)
-    schedule = memory_planner(schedule)
+    with cpu_profile(TracingKey("toposort schedule")): schedule, var_vals = create_schedule_with_vars(sink)
+    with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
     if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
     return schedule, var_vals
```
```diff
@@ -275,7 +275,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       return tuple(1 if i in axis_arg else s for i,s in enumerate(ps))
 
     # elementwise ops keep the shape the same. all inputs with shape must match
-    if self.op in (GroupOp.Elementwise-{Ops.BITCAST}).union({Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
+    if self.op in GroupOp.ALU.union({Ops.CAST, Ops.COPY, Ops.ASSIGN, Ops.NOOP, Ops.GROUP, Ops.SINK, Ops.ALLREDUCE, Ops.STORE}):
       # TODO: remove this hack for 3 op assign
       input_shapes = [x._shape for x in (self.src[:2] if self.op is Ops.ASSIGN else self.src) if x._shape is not None]
       if len(input_shapes) == 0: return None
```
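This last hunk rewrites a membership test over op groups. Assuming `GroupOp.Elementwise` is `GroupOp.ALU` plus the two cast ops (an assumption; the group definitions are not shown in this diff), the old and new spellings denote the same set:

```python
# Set identity behind the rewritten predicate, with stand-in op names.
# Assumes Elementwise == ALU | {CAST, BITCAST}, which this diff does not itself confirm.
ALU = {"ADD", "MUL", "WHERE"}
Elementwise = ALU | {"CAST", "BITCAST"}
assert Elementwise - {"BITCAST"} == ALU | {"CAST"}
```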