
Improved ONNX support with dynamic shapes #117

Closed

Description

@xenova

Hi there! 👋 Following the conversation in #103, I wanted to export the models so they (1) support dynamic shapes and (2) return the surface normal predictions, mainly so I could run them with Transformers.js. I got them working, and I've uploaded them to the Hugging Face Hub:

(you can find the .onnx weights, both fp32 and fp16, in the onnx subfolder)

Feel free to use them yourself or add the links to the README for increased visibility! 🤗 PS: I'd also recommend uploading your original PyTorch checkpoints to separate repos (instead of a single repo). Let me know if I can help with any of this!
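
If you just want to try one of the uploaded checkpoints from Python, here's a minimal sketch using onnxruntime; the repo id and filename below are placeholders, so substitute the actual ones from the links above:

from huggingface_hub import hf_hub_download
import onnxruntime as ort
import numpy as np

# Placeholder repo id / filename - replace with the actual repo and file from the links above
model_path = hf_hub_download("<user>/metric3d-vit-small", "onnx/model.onnx")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# The exported graph normalizes internally, so raw 0-255 RGB pixel values go straight in
image = (np.random.rand(1, 3, 280, 420) * 255).astype(np.float32)
depth, normal, confidence = session.run(None, {"pixel_values": image})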

Regarding the export, there were a few things to consider, mainly patching the modelling code to avoid Python type casts (so the dynamic shapes survive tracing). I also made a few modifications to support CPU exports. Here's my conversion code:

import torch
import math
import torch.nn as nn

class NullContext:
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass

# Do not autocast to bf16 or cuda
torch.autocast = NullContext

class Metric3DExportModel(nn.Module):
    """
    The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
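        # ImageNet mean/std in the 0-255 range, registered as buffers so normalization
        # is baked into the exported graph and the model accepts raw RGB pixel values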
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )

    def normalize_image(self, image):
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )

        pred_depth = pred_depth.squeeze(1)
        pred_normal = output_dict['prediction_normal'][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
        normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details

        return pred_depth, pred_normal, normal_confidence


def patch_model(model):

    def interpolate_pos_encoding(self, x, w, h):
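        # Always interpolate the positional embeddings, using tensor-valued sizes on the
        # tracing branch (no Python int casts), so height and width stay dynamic in the export.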
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Comment out this code (so we always interpolate)
        # if npatch == N and w == h:
        #     return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        if torch.jit.is_tracing():
            sqrt_N = N ** 0.5
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
                size=(w0, h0),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )
        else:
            sqrt_N = math.sqrt(N)
            sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
                scale_factor=(sx, sy),
                mode="bicubic",
                antialias=self.interpolate_antialias,
            )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )

    def get_bins(self, bins_num):
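        # Recompute the depth bins with plain torch ops (no device-specific calls),
        # so the decoder also traces correctly for a CPU export.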
        depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num)
        depth_bins_vec = torch.exp(depth_bins_vec)
        return depth_bins_vec

    model.depth_model.decoder.get_bins = (
        get_bins.__get__(
            model.depth_model.decoder, model.depth_model.decoder.__class__
        )
    )

    return model

# Load model
model_name = "metric3d_vit_small" # or "metric3d_vit_large" or "metric3d_vit_giant2"
model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
model.eval()

# Patch model so we can export to ONNX
model = patch_model(model)
export_model = Metric3DExportModel(model)
export_model.eval()

# Export the model
dummy_image = torch.randn([2, 3, 280, 420])
onnx_output = f"{model_name}.onnx"
torch.onnx.export(
    export_model,
    (dummy_image, ),
    onnx_output,
    input_names=["pixel_values"],
    output_names=["predicted_depth", "predicted_normal", "normal_confidence"],
    opset_version=11,
    dynamic_axes={
        "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
        "predicted_depth": {0: "batch_size", 1: "height", 2: "width"},
        "predicted_normal": {0: "batch_size", 2: "height", 3: "width"},
        "normal_confidence": {0: "batch_size", 1: "height", 2: "width"},
    },
)
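
To sanity-check the dynamic axes, here's a quick sketch (assuming onnxruntime is installed) that runs the exported graph at two different resolutions, both multiples of the ViT patch size (14):

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession(onnx_output, providers=["CPUExecutionProvider"])

for height, width in [(280, 420), (364, 518)]:
    # Raw 0-255 RGB input; normalization happens inside the exported graph
    pixel_values = (np.random.rand(1, 3, height, width) * 255).astype(np.float32)
    depth, normal, confidence = session.run(None, {"pixel_values": pixel_values})
    print((height, width), depth.shape, normal.shape, confidence.shape)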
