Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
add name support to fetch (#2407)
* add name support
* use fetch in gpt2
* remove requests from main lib, networkx also optional
* umm, keep that assert
* updates to fetch
* i love the walrus so much
* stop bundling mnist with tinygrad
* err, https
* download cache names
* add DOWNLOAD_CACHE_VERSION
* need env.
* ugh, wrong path
* replace get_child
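The headline change is that fetch() now takes a name for its cache entry, so callers can drop hand-rolled download paths. A minimal usage sketch of the pattern this commit adopts (the cache layout and the DOWNLOAD_CACHE_VERSION handling are assumptions taken from the message above, not shown in this diff):

from tinygrad.helpers import fetch
from tinygrad.nn.state import torch_load

# fetch() downloads the URL once and returns the local cached path; the
# optional second argument names the cache entry (assumption: the cache is
# versioned via the DOWNLOAD_CACHE_VERSION env var mentioned above)
ckpt_path = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
state_dict = torch_load(ckpt_path)['state_dict']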
@@ -4,7 +4,7 @@ from examples.stable_diffusion import StableDiffusion
 from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from extra.utils import download_file
+from tinygrad.helpers import fetch
 from typing import NamedTuple, Any, List
 from pathlib import Path
 import argparse
@@ -28,8 +28,6 @@ def convert_f32_to_f16(input_file, output_file):
     front_float16_values.tofile(f)
     rest_float32_values.tofile(f)
 
-FILENAME = Path(__file__).parent.parent.parent.parent / "weights/sd-v1-4.ckpt"
-
 def split_safetensor(fn):
   _, json_len, metadata = safe_load_metadata(fn)
   text_model_offset = 3772703308
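Only the two tofile() calls of convert_f32_to_f16 appear in this hunk. For orientation, a hedged sketch of the overall shape (the split point and the safetensors header handling are assumptions, mirroring the text_model_offset constant used by split_safetensor below):

import numpy as np

def convert_f32_to_f16_sketch(input_file, output_file, split_bytes=3772703308):
  # header parsing omitted; assumes a raw float32 payload
  float32_values = np.fromfile(input_file, dtype=np.float32)
  n = split_bytes // 4  # float32 elements before the assumed split point
  front_float16_values = float32_values[:n].astype(np.float16)
  rest_float32_values = float32_values[n:]
  with open(output_file, 'wb') as f:
    front_float16_values.tofile(f)  # everything before the text model as f16
    rest_float32_values.tofile(f)   # text model weights stay f32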
@@ -40,7 +38,7 @@ def split_safetensor(fn):
     if (metadata[k]["data_offsets"][0] < text_model_offset):
       metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
       metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
 
   last_offset = 0
   part_end_offsets = []
 
@@ -51,7 +49,7 @@ def split_safetensor(fn):
       break
 
     part_offset = offset - last_offset
 
     if (part_offset >= chunk_size):
       part_end_offsets.append(8+json_len+offset)
       last_offset = offset
@@ -60,7 +58,7 @@ def split_safetensor(fn):
   net_bytes = bytes(open(fn, 'rb').read())
   part_end_offsets.append(text_model_start+8+json_len)
   cur_pos = 0
 
   for i, end_pos in enumerate(part_end_offsets):
     with open(f'./net_part{i}.safetensors', "wb+") as f:
       f.write(net_bytes[cur_pos:end_pos])
@@ -68,7 +66,7 @@ def split_safetensor(fn):
 
   with open(f'./net_textmodel.safetensors', "wb+") as f:
     f.write(net_bytes[text_model_start+8+json_len:])
 
   return part_end_offsets
 
 if __name__ == "__main__":
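split_safetensor cuts the fp16-converted weight file into parts of roughly chunk_size bytes at tensor boundaries, plus a separate file for the float32 text model. A hypothetical inverse, useful for sanity-checking the split (not part of this commit):

from pathlib import Path

def join_safetensor(out_fn, num_parts):
  # concatenate the split parts back into one safetensors blob
  with open(out_fn, "wb") as out:
    for i in range(num_parts):
      out.write(Path(f'./net_part{i}.safetensors').read_bytes())
    out.write(Path('./net_textmodel.safetensors').read_bytes())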
@@ -81,8 +79,7 @@ if __name__ == "__main__":
   model = StableDiffusion()
 
   # load in weights
-  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
-  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
+  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
 
   class Step(NamedTuple):
     name: str = ""
@@ -90,11 +87,11 @@ if __name__ == "__main__":
     forward: Any = None
 
   sub_steps = [
     Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
     Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
     Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
   ]
 
   prg = ""
 
   def compile_step(model, step: Step):
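Each Step pairs a stage name with example inputs and the callable to compile; the driver that consumes sub_steps sits outside this hunk. Presumably each stage can be smoke-tested eagerly along these lines (hypothetical, not in the diff):

# run every stage once on its dummy inputs before compiling it
for step in sub_steps:
  out = step.forward(*step.input)
  print(step.name, out.shape)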
@@ -109,7 +106,7 @@ if __name__ == "__main__":
     gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
     input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
     return f"""\n  var {step.name} = function() {{
 
     {kernel_code}
 
     return {{
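The two templates above generate the JavaScript that stages input data into GPU memory. To see what they expand to, the input_writer fragment for a single hypothetical input (i=0) can be printed directly; this just mirrors the f-string in the hunk:

# illustration: expand the input_writer template for one input buffer
i = 0
print(f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n"
      f"new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(data{i});\n"
      f"gpuWriteBuffer{i}.unmap();\n"
      f"commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);")
# the generated JS maps a staging buffer, writes the Float32Array, unmaps,
# then copies the staging buffer into the kernel's input buffer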
@@ -117,7 +114,7 @@ if __name__ == "__main__":
     const metadata = getTensorMetadata(safetensor[0]);
 
     {bufs}
 
     {gpu_write_bufs}
     const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
 
@@ -140,7 +137,7 @@ if __name__ == "__main__":
         gpuReadBuffer.unmap();
         return resultBuffer;
       }}
     }}
   }}
 }}
 """