mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 23:08:19 -05:00
Compare commits
5 Commits
diffusers-
...
20230804.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96ced18f90 | ||
|
|
167c6cc349 | ||
|
|
5a091ae3f0 | ||
|
|
cefcc45873 | ||
|
|
e2b4de8c0a |
@@ -690,8 +690,14 @@ class ShardedVicuna(VicunaBase):
|
||||
# f_ = open(mlir_path, "wb")
|
||||
# f_.write(bytecode)
|
||||
# f_.close()
|
||||
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/lmhead.mlir lmhead.mlir"
|
||||
subprocess.check_call(command.split())
|
||||
# command = f"gsutil cp gs://shark_tank/elias/compressed_sv/lmhead.mlir lmhead.mlir"
|
||||
# subprocess.check_call(command.split())
|
||||
filepath = Path("lmhead.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/lmhead.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"lmhead.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
@@ -732,8 +738,14 @@ class ShardedVicuna(VicunaBase):
|
||||
# use_tracing=False,
|
||||
# verbose=False,
|
||||
# )
|
||||
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/norm.mlir norm.mlir"
|
||||
subprocess.check_call(command.split())
|
||||
# command = f"gsutil cp gs://shark_tank/elias/compressed_sv/norm.mlir norm.mlir"
|
||||
# subprocess.check_call(command.split())
|
||||
filepath = Path("norm.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/norm.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"norm.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
@@ -779,8 +791,14 @@ class ShardedVicuna(VicunaBase):
|
||||
# f_ = open(mlir_path, "wb")
|
||||
# f_.write(bytecode)
|
||||
# f_.close()
|
||||
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/embedding.mlir embedding.mlir"
|
||||
subprocess.check_call(command.split())
|
||||
# command = f"gsutil cp gs://shark_tank/elias/compressed_sv/embedding.mlir embedding.mlir"
|
||||
# subprocess.check_call(command.split())
|
||||
filepath = Path("embedding.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/embedding.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
f_ = open(f"embedding.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
@@ -963,6 +981,8 @@ class ShardedVicuna(VicunaBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
@@ -986,9 +1006,15 @@ class ShardedVicuna(VicunaBase):
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
else:
|
||||
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/{idx}_full.mlir {idx}_full.mlir"
|
||||
# command = f"gsutil cp gs://shark_tank/elias/compressed_sv/{idx}_full.mlir {idx}_full.mlir"
|
||||
|
||||
subprocess.check_call(command.split())
|
||||
# subprocess.check_call(command.split())
|
||||
filepath = Path(f"{idx}_full.mlir")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/elias/compressed_sv/{idx}_full.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
|
||||
f_ = open(f"{idx}_full.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
@@ -1026,6 +1052,8 @@ class ShardedVicuna(VicunaBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
@@ -1150,7 +1178,7 @@ class ShardedVicuna(VicunaBase):
|
||||
layers0 = [layers00, layers01, layers02, layers03]
|
||||
layers1 = [layers10, layers11, layers12, layers13]
|
||||
|
||||
_, modules = self.compile_to_vmfb_one_model(
|
||||
_, modules = self.compile_to_vmfb_one_model4(
|
||||
placeholder_input0,
|
||||
layers0,
|
||||
placeholder_input1,
|
||||
@@ -1177,6 +1205,9 @@ class ShardedVicuna(VicunaBase):
|
||||
return self.get_sharded_model(
|
||||
device=device, compressed=self.compressed
|
||||
)
|
||||
return self.get_sharded_model(
|
||||
device=device, compressed=self.compressed
|
||||
)
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
# TODO: refactor for cleaner integration
|
||||
@@ -1617,6 +1648,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-opt-const-expr-hoisting=False",
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
],
|
||||
)
|
||||
print("Saved vic vmfb at ", str(path))
|
||||
|
||||
Reference in New Issue
Block a user