RuntimeError: Expected tensor for argument #1 'input' to have the same type as tensor for argument #2 'weight'; but type torch.cuda.FloatTensor does not equal torch.cuda.HalfTensor (while checking arguments for cudnn_batch_norm) #42

@wzr0108

I got the following error when running torch2trt.py:

(PersonaLive) wzr@DESKTOP-H41D2AV:~/PersonaLive$ uv run torch2trt.py 
Uninstalled 1 package in 0.66ms
Installed 1 package in 23ms
/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
  _torch_pytree._register_pytree_node(
/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/transformers/utils/generic.py:309: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
  _torch_pytree._register_pytree_node(
Warning: Accessing the `severity` property of G_LOGGER is deprecated and will be removed in v0.50.0. Use `module_severity` instead
['pred_video', 'latents', 'pose_cond_fea_out', 'motion_hidden_states_out', 'motion_out', 'latent_first']
Starting ONNX model export to: ./pretrained_weights/onnx/unet/unet.onnx ...
/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/diffusers/models/embeddings.py:175: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.height != height or self.width != width:
/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py:100: UserWarning: `nn.functional.upsample` is deprecated. Use `nn.functional.interpolate` instead.
  up2 = F.upsample(low3, size=rescale_size, mode='bilinear')
Traceback (most recent call last):
  File "/home/wzr/PersonaLive/torch2trt.py", line 93, in <module>
    export_onnx(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/wzr/PersonaLive/src/modeling/onnx_export.py", line 59, in export_onnx
    torch.onnx.utils.export(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 522, in export
    _export(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1457, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1080, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 964, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 871, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 1504, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 138, in forward
    graph, _out = torch._C._create_graph_by_tracing(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/src/modeling/framed_models.py", line 28, in forward
    new_motion_hidden_states = self.motion_encoder(motion)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/encoder.py", line 37, in forward
    latent = self.model(rearrange(x, "b c f h w -> (b f) c h w"))
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 314, in forward
    hg = self._modules['m' + str(i)](previous)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 105, in forward
    return self._forward(self.depth, x)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 91, in _forward
    low2 = self._forward(level - 1, low1)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 91, in _forward
    low2 = self._forward(level - 1, low1)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 97, in _forward
    low3 = self._modules['b3_' + str(level)](low3)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/src/models/motion_encoder/FAN_feature_extractor.py", line 37, in forward
    out1 = self.bn1(x)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1763, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 193, in forward
    return F.batch_norm(
  File "/home/wzr/PersonaLive/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2817, in batch_norm
    return torch.batch_norm(
RuntimeError: Expected tensor for argument #1 'input' to have the same type as tensor for argument #2 'weight'; but type torch.cuda.FloatTensor does not equal torch.cuda.HalfTensor (while checking arguments for cudnn_batch_norm)
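
The message says that the tensor reaching `bn1` is float32 (`torch.cuda.FloatTensor`) while the batch-norm weight is float16 (`torch.cuda.HalfTensor`), so some part of the model and its input are in different precisions at trace time. The root cause inside PersonaLive is not confirmed here. As a minimal sketch of how such a mismatch could be located and removed before calling the exporter, assuming a small `nn.Sequential` stand-in for the motion encoder and a hypothetical helper `find_dtype_mismatches` (neither is part of the repository):

```python
import torch
from torch import nn


def find_dtype_mismatches(module: nn.Module, expected: torch.dtype) -> list[str]:
    """Return names of floating-point parameters/buffers whose dtype is not `expected`."""
    mismatched = []
    tensors = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in tensors:
        # Integer buffers such as num_batches_tracked are skipped on purpose.
        if tensor.is_floating_point() and tensor.dtype != expected:
            mismatched.append(name)
    return mismatched


# Hypothetical stand-in model: the batch-norm layer is left in float16,
# mimicking the HalfTensor weight reported in the traceback.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
model[1].half()

# Names of tensors that do not match the precision of a float32 input.
before = find_dtype_mismatches(model, torch.float32)

# One possible fix (an assumption, not the project's confirmed fix):
# cast the whole module and the example input to a single dtype before tracing.
export_dtype = torch.float32
model = model.to(dtype=export_dtype).eval()
example_input = torch.randn(1, 3, 16, 16).to(dtype=export_dtype)

after = find_dtype_mismatches(model, export_dtype)
with torch.no_grad():
    output = model(example_input)
```

The same check works in the other direction: if the export is meant to run in float16, pass `torch.float16` as the expected dtype and cast both the model and every example input with `.half()` instead.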
