test_pt_flax_equivalence and test_encoder_decoder_model_standalone fail running on device (cuda or xpu) #33517

Open
dvrogozh opened this issue Sep 16, 2024 · 0 comments · May be fixed by #33485

dvrogozh commented Sep 16, 2024

With:

Issue seen on NVIDIA A10 and Intel PVC.

test_pt_flax_equivalence and test_encoder_decoder_model_standalone fail across multiple models because models or tensors are not placed on the right device. Specifically, there are 3 types of issues:

  1. The model was not moved to the device (model.to(device) is missing)
  2. The input was not moved to the device (input.to(device) is missing)
  3. torch.Tensor.numpy() is called on a tensor that is still on the device (the tensor should first be moved to CPU according to https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html)
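
The three placements above can be sketched as follows. This is a minimal illustration, not the code from the proposed fix: the `nn.Embedding` model, the `input_ids` tensor, and the `device` selection are assumptions made for the example, and the sketch falls back to CPU when no CUDA device is available.

```python
import torch
from torch import nn

# Assumption: pick CUDA when present, otherwise CPU, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for a test's model and inputs.
model = nn.Embedding(num_embeddings=10, embedding_dim=4)
input_ids = torch.tensor([[1, 2, 3]])

# 1. Move the model to the device.
model = model.to(device)

# 2. Move the input to the same device as the model.
input_ids = input_ids.to(device)

with torch.no_grad():
    output = model(input_ids)

# 3. Copy the result to host memory before converting it to NumPy.
output_np = output.cpu().numpy()
```

Calling `.cpu()` on a tensor that already lives on the CPU returns the tensor unchanged, so step 3 is safe to apply unconditionally.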

Proposed fix:

CC: @sanchit-gandhi, @amyeroberts

See the following log for the repro command line and the list of errors (this log is from a run on NVIDIA A10; the XPU log will be similar):

$ python3 -m pytest --tb=short \
tests/models/informer/test_modeling_informer.py::InformerModelTest::test_encoder_decoder_model_standalone \
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxGPT2EncoderDecoderModelTest::test_pt_flax_equivalence \
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxBartEncoderDecoderModelTest::test_pt_flax_equivalence \
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxBertEncoderDecoderModelTest::test_pt_flax_equivalence \
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::ViTBertModelTest::test_pt_flax_equivalence \
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::CLIPVisionBertModelTest::test_pt_flax_equivalence \
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2GPT2ModelTest::test_pt_flax_equivalence \
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2BartModelTest::test_pt_flax_equivalence \
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2BertModelTest::test_pt_flax_equivalence \
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxViTBertModelTest::test_pt_flax_equivalence \
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxCLIPVisionBertModelTest::test_pt_flax_equivalence \
tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py::FlaxViT2GPT2EncoderDecoderModelTest::test_pt_flax_equivalence
========================================================================================= test session starts =========================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/dvrogozh/git/huggingface/transformers
configfile: pyproject.toml
plugins: hypothesis-6.111.1, subtests-0.13.1, rich-0.1.1, dash-2.17.1, xdist-3.6.1, pspec-0.0.4, timeout-2.3.1
collected 12 items

tests/models/informer/test_modeling_informer.py F                                                                                                                                               [  8%]
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py FFF                                                                                                                          [ 33%]
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py FF                                                                                                              [ 50%]
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py FFF                                                                                                            [ 75%]
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py FF                                                                                                         [ 91%]
tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py F                                                                                                              [100%]

============================================================================================== FAILURES ===============================================================================================
_______________________________________________________________________ InformerModelTest.test_encoder_decoder_model_standalone _______________________________________________________________________
tests/models/informer/test_modeling_informer.py:226: in test_encoder_decoder_model_standalone
    self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
tests/models/informer/test_modeling_informer.py:174: in check_encoder_decoder_model_standalone
    self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
E   RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument other in method wrapper_CUDA__equal)
______________________________________________________________________ FlaxGPT2EncoderDecoderModelTest.test_pt_flax_equivalence _______________________________________________________________________
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:413: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:344: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:303: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py:597: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:1077: in forward
    embedding_output = self.embeddings(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:210: in forward
    inputs_embeds = self.word_embeddings(input_ids)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/sparse.py:190: in forward
    return F.embedding(
../../pytorch/pytorch/torch/nn/functional.py:2551: in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
E   RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
______________________________________________________________________ FlaxBartEncoderDecoderModelTest.test_pt_flax_equivalence _______________________________________________________________________
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:413: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:344: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:303: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py:597: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:1077: in forward
    embedding_output = self.embeddings(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:210: in forward
    inputs_embeds = self.word_embeddings(input_ids)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/sparse.py:190: in forward
    return F.embedding(
../../pytorch/pytorch/torch/nn/functional.py:2551: in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
E   RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
---------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------
Config of the decoder: <class 'transformers.models.bart.modeling_bart.BartForCausalLM'> is overwritten by shared decoder config: BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 32,
  "decoder_attention_heads": 4,
  "decoder_ffn_dim": 4,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "encoder_attention_heads": 4,
  "encoder_ffn_dim": 4,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "initializer_range": 0.02,
  "is_decoder": true,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 32,
  "model_type": "bart",
  "num_hidden_layers": 2,
  "pad_token_id": 1,
  "scale_embedding": false,
  "transformers_version": "4.45.0.dev0",
  "use_cache": false,
  "vocab_size": 99
}

______________________________________________________________________ FlaxBertEncoderDecoderModelTest.test_pt_flax_equivalence _______________________________________________________________________
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:413: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:344: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py:303: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py:597: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:1077: in forward
    embedding_output = self.embeddings(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/bert/modeling_bert.py:210: in forward
    inputs_embeds = self.word_embeddings(input_ids)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/sparse.py:190: in forward
    return F.embedding(
../../pytorch/pytorch/torch/nn/functional.py:2551: in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
E   RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
______________________________________________________________________________ ViTBertModelTest.test_pt_flax_equivalence ______________________________________________________________________________
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:266: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:226: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:182: in check_pt_flax_equivalence
    flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:182: in <dictcomp>
    flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
E   TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
__________________________________________________________________________ CLIPVisionBertModelTest.test_pt_flax_equivalence ___________________________________________________________________________
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:266: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:226: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:182: in check_pt_flax_equivalence
    flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py:182: in <dictcomp>
    flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
E   TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
_________________________________________________________________________ FlaxWav2Vec2GPT2ModelTest.test_pt_flax_equivalence __________________________________________________________________________
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:532: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:459: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:418: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py:501: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:1809: in forward
    extract_features = self.feature_extractor(input_values)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:463: in forward
    hidden_states = conv_layer(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:332: in forward
    hidden_states = self.conv(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:375: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:370: in _conv_forward
    return F.conv1d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
_________________________________________________________________________ FlaxWav2Vec2BartModelTest.test_pt_flax_equivalence __________________________________________________________________________
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:532: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:459: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:418: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py:501: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:1809: in forward
    extract_features = self.feature_extractor(input_values)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:463: in forward
    hidden_states = conv_layer(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:332: in forward
    hidden_states = self.conv(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:375: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:370: in _conv_forward
    return F.conv1d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
---------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------
Config of the decoder: <class 'transformers.models.bart.modeling_bart.BartForCausalLM'> is overwritten by shared decoder config: BartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 24,
  "decoder_attention_heads": 4,
  "decoder_ffn_dim": 4,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 2,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "encoder_attention_heads": 4,
  "encoder_ffn_dim": 4,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 2,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "initializer_range": 0.02,
  "is_decoder": true,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 32,
  "model_type": "bart",
  "num_hidden_layers": 2,
  "pad_token_id": 1,
  "scale_embedding": false,
  "transformers_version": "4.45.0.dev0",
  "use_cache": false,
  "vocab_size": 99
}

_________________________________________________________________________ FlaxWav2Vec2BertModelTest.test_pt_flax_equivalence __________________________________________________________________________
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:532: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:459: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py:418: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py:501: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:1809: in forward
    extract_features = self.feature_extractor(input_values)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:463: in forward
    hidden_states = conv_layer(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/wav2vec2/modeling_wav2vec2.py:332: in forward
    hidden_states = self.conv(hidden_states)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:375: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:370: in _conv_forward
    return F.conv1d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
____________________________________________________________________________ FlaxViTBertModelTest.test_pt_flax_equivalence ____________________________________________________________________________
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:243: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:207: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:166: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py:358: in forward
    vision_outputs = self.vision_model(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:619: in forward
    embedding_output = self.embeddings(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:124: in forward
    embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:183: in forward
    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:554: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:549: in _conv_forward
    return F.conv2d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
________________________________________________________________________ FlaxCLIPVisionBertModelTest.test_pt_flax_equivalence _________________________________________________________________________
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:243: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:207: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py:166: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py:358: in forward
    vision_outputs = self.vision_model(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/clip/modeling_clip.py:1116: in forward
    return self.vision_model(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/clip/modeling_clip.py:1040: in forward
    hidden_states = self.embeddings(pixel_values)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/clip/modeling_clip.py:202: in forward
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:554: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:549: in _conv_forward
    return F.conv2d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
____________________________________________________________________ FlaxViT2GPT2EncoderDecoderModelTest.test_pt_flax_equivalence _____________________________________________________________________
tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py:352: in test_pt_flax_equivalence
    self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py:288: in check_equivalence_pt_to_flax
    self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py:247: in check_pt_flax_equivalence
    pt_outputs = pt_model(**pt_inputs).to_tuple()
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py:587: in forward
    encoder_outputs = self.encoder(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:619: in forward
    embedding_output = self.embeddings(
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:124: in forward
    embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
src/transformers/models/vit/modeling_vit.py:183: in forward
    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
../../pytorch/pytorch/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
../../pytorch/pytorch/torch/nn/modules/conv.py:554: in forward
    return self._conv_forward(input, self.weight, self.bias)
../../pytorch/pytorch/torch/nn/modules/conv.py:549: in _conv_forward
    return F.conv2d(
E   RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
========================================================================================== warnings summary ===========================================================================================
../../../pytorch.cuda/lib/python3.10/site-packages/tensorflow/__init__.py:30
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/tensorflow/__init__.py:30: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
    import distutils as _distutils

src/transformers/deepspeed.py:24
  /home/dvrogozh/git/huggingface/transformers/src/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
    warnings.warn(

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:46
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:46: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    v: jnp.DeviceArray,

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:48
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:48: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    inputs: jnp.DeviceArray,

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:49
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:49: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    targets: jnp.DeviceArray,

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:50
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:50: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    ) -> jnp.DeviceArray:

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:72
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:72: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    inputs: jnp.DeviceArray,

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:73
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:73: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    targets: jnp.DeviceArray,

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:74
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:74: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    ) -> jnp.DeviceArray:

../../../pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:97
  /home/dvrogozh/pytorch.cuda/lib/python3.10/site-packages/optax/_src/second_order.py:97: DeprecationWarning: jax.numpy.DeviceArray is deprecated. Use jax.Array.
    ) -> jnp.DeviceArray:

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================= short test summary info =======================================================================================
FAILED tests/models/informer/test_modeling_informer.py::InformerModelTest::test_encoder_decoder_model_standalone - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument other in method wrapper_CUDA__equal)
FAILED tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxGPT2EncoderDecoderModelTest::test_pt_flax_equivalence - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
FAILED tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxBartEncoderDecoderModelTest::test_pt_flax_equivalence - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
FAILED tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py::FlaxBertEncoderDecoderModelTest::test_pt_flax_equivalence - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::ViTBertModelTest::test_pt_flax_equivalence - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py::CLIPVisionBertModelTest::test_pt_flax_equivalence - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2GPT2ModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2BartModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py::FlaxWav2Vec2BertModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxViTBertModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py::FlaxCLIPVisionBertModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
FAILED tests/models/vision_encoder_decoder/test_modeling_flax_vision_encoder_decoder.py::FlaxViT2GPT2EncoderDecoderModelTest::test_pt_flax_equivalence - RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
================================================================================== 12 failed, 10 warnings in 19.19s ===================================================================================
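Most of the failures in the summary above are device mismatches: either the model or its inputs stayed on CPU while the other side was on the accelerator. Below is a minimal sketch of the placement pattern, not the actual test code. The `nn.Conv2d` module and the `inputs_dict` contents are hypothetical stand-ins for the model and inputs built in the tests, and the device selection is an assumption that falls back to CPU when CUDA is not available.

```python
import torch
from torch import nn

# Assumption: use CUDA when available, otherwise stay on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the PyTorch model built inside the test.
# Moving the module places its weights on the target device.
model = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2).to(device).eval()

# Hypothetical stand-in for the test inputs; tensors are created on CPU by default.
inputs_dict = {"pixel_values": torch.rand(1, 3, 8, 8)}

# Move every tensor input to the same device as the model before the forward pass.
pt_inputs = {name: tensor.to(device) for name, tensor in inputs_dict.items()}

with torch.no_grad():
    output = model(pt_inputs["pixel_values"])

print(output.device.type, tuple(output.shape))
```

If either `.to(device)` call is dropped while `device` is an accelerator, the forward pass would be expected to raise a mismatch error like the `RuntimeError` lines in the log.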
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Sep 16, 2024
This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to the PyTorch documentation, torch.Tensor.numpy(force=False)
performs the conversion only if the tensor is on CPU (plus a few other
restrictions), which is not the case here. We need force=True since we
only need the data and don't care about tensor coherency.

Fixes: huggingface#33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html
Signed-off-by: Dmitry Rogozhkin <[email protected]>
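
For reference, a small sketch of what `force=True` does, based on the PyTorch documentation linked in the commit message above. The tensor and the device selection are illustrative assumptions, not code from the pull request. On a machine without CUDA the example stays on CPU and both conversions still agree.

```python
import numpy as np
import torch

# Assumption: use CUDA when available, otherwise stay on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative tensor that lives on the selected device.
t = torch.arange(6, dtype=torch.float32, device=device).reshape(2, 3)

# With the default force=False, t.numpy() raises a TypeError for a tensor on
# an accelerator. force=True copies the data to host memory when needed.
forced = t.numpy(force=True)

# Explicit equivalent for a real-valued tensor: detach, move to CPU, convert.
explicit = t.detach().cpu().numpy()

np.testing.assert_array_equal(forced, explicit)
print(type(forced).__name__, forced.shape)
```

The array returned with `force=True` may be a copy rather than a view of the tensor's storage, which is acceptable when the values are only read for comparison.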
@dvrogozh dvrogozh linked a pull request Sep 16, 2024 that will close this issue