Description
Hi there! 👋 Following the conversation in #103, I wanted to export the models so they (1) support dynamic shapes and (2) return the surface normal predictions as well, mainly so the models can be run with Transformers.js. I got them working, and I've uploaded them to the Hugging Face Hub:
- https://huggingface.co/onnx-community/metric3d-vit-small
- https://huggingface.co/onnx-community/metric3d-vit-large
- https://huggingface.co/onnx-community/metric3d-vit-giant2
(you can find the .onnx weights - both fp32 and fp16 - in the onnx subfolder)
Feel free to use them yourself, or add the links to the README for increased visibility! 🤗

PS: I'd also recommend uploading your original PyTorch checkpoints to separate repos (instead of a single repo). Let me know if I can help with any of this!
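If you want to try the exported weights outside of Transformers.js, here's a minimal Python sketch using onnxruntime and huggingface_hub. Note that the onnx/model.onnx file path and the output order are assumptions based on the export script below, so double-check them against the repo if anything doesn't match:

```python
# Rough sketch (not part of the export script): run one of the uploaded models
# with onnxruntime. Assumes the fp32 graph is stored at onnx/model.onnx and that
# normalization is baked into the graph, so the input is raw RGB in [0, 255].
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download

model_path = hf_hub_download("onnx-community/metric3d-vit-small", "onnx/model.onnx")
session = ort.InferenceSession(model_path)

# Read the input name from the graph rather than hardcoding it
input_name = session.get_inputs()[0].name

# Dummy 1x3xHxW image; H and W should be multiples of the ViT patch size (14)
image = np.random.randint(0, 256, (1, 3, 280, 420)).astype(np.float32)
depth, normal, normal_confidence = session.run(None, {input_name: image})
print(depth.shape, normal.shape, normal_confidence.shape)
```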
Regarding the export, there were a few things to consider, mainly fixing the modelling code to avoid Python type casts (so that the dynamic shapes survive tracing). I also made a few modifications to support exporting on CPU. Here's my conversion code:
```python
import math

import torch
import torch.nn as nn


class NullContext:
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass


# Do not autocast to bf16 or cuda
torch.autocast = NullContext


class Metric3DExportModel(nn.Module):
    """
    The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )

    def normalize_image(self, image):
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )
        pred_depth = pred_depth.squeeze(1)
        # Normals are only available for Metric3Dv2, i.e. the ViT models
        pred_normal = output_dict["prediction_normal"][:, :3, :, :]
        # See https://arxiv.org/abs/2109.09881 for details on the confidence channel
        normal_confidence = output_dict["prediction_normal"][:, 3, :, :]
        return pred_depth, pred_normal, normal_confidence


def patch_model(model):
    """Monkey-patch the encoder and decoder with export-friendly implementations."""

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # The original early return is commented out, so we always interpolate:
        # if npatch == N and w == h:
        #     return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # We add a small number to avoid floating point error in the interpolation,
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        if torch.jit.is_tracing():
            # Keep everything as tensor ops (no int()/math.sqrt() casts) so the
            # shapes stay symbolic in the exported graph
            sqrt_N = N ** 0.5
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, sqrt_N.to(torch.int64), sqrt_N.to(torch.int64), dim).permute(0, 3, 1, 2),
                size=(w0, h0),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )
        else:
            sqrt_N = math.sqrt(N)
            sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
                scale_factor=(sx, sy),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )
        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    model.depth_model.encoder.interpolate_pos_encoding = interpolate_pos_encoding.__get__(
        model.depth_model.encoder, model.depth_model.encoder.__class__
    )

    def get_bins(self, bins_num):
        depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num)
        depth_bins_vec = torch.exp(depth_bins_vec)
        return depth_bins_vec

    model.depth_model.decoder.get_bins = get_bins.__get__(
        model.depth_model.decoder, model.depth_model.decoder.__class__
    )

    return model


# Load model
model_name = "metric3d_vit_small"  # or "metric3d_vit_large" or "metric3d_vit_giant2"
model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
model.eval()

# Patch model so we can export to ONNX
model = patch_model(model)
export_model = Metric3DExportModel(model)
export_model.eval()

# Export the model
dummy_image = torch.randn([2, 3, 280, 420])
onnx_output = f"{model_name}.onnx"
torch.onnx.export(
    export_model,
    (dummy_image,),
    onnx_output,
    input_names=["pixel_values"],
    output_names=["predicted_depth", "predicted_normal", "normal_confidence"],
    opset_version=11,
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "predicted_depth": {0: "batch_size", 1: "height", 2: "width"},
        "predicted_normal": {0: "batch_size", 2: "height", 3: "width"},
        "normal_confidence": {0: "batch_size", 1: "height", 2: "width"},
    },
)
```
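For anyone adapting this, a quick way to confirm the dynamic axes actually survived the export (i.e. that no Python casts froze the 280x420 trace shape into the graph) is to run the exported file at a different resolution and compare it against the PyTorch model. Here's a rough sketch, continuing from the script above and assuming onnxruntime is installed:

```python
# Sanity check: run the exported graph at a resolution different from the one
# used for tracing and compare against the (patched) PyTorch model.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(onnx_output)

test_image = torch.randn([1, 3, 308, 364])  # multiples of the ViT patch size (14)
with torch.no_grad():
    torch_depth, torch_normal, _ = export_model(test_image)

ort_depth, ort_normal, _ = session.run(None, {"pixel_values": test_image.numpy()})

print("depth shape:", ort_depth.shape)  # should follow the new input resolution
print("max abs diff (depth): ", np.abs(ort_depth - torch_depth.numpy()).max())
print("max abs diff (normal):", np.abs(ort_normal - torch_normal.numpy()).max())
```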