Updating memory fails for datasets that are not bipartite #29

daniel-gomm opened this issue Nov 20, 2023 · 1 comment
daniel-gomm commented Nov 20, 2023

Hi,

If I am not mistaken, there seems to be a bug when using the model on a unipartite dataset while updating the memory at the end of each batch (`memory_update_at_start=False`).

Running the model like this incorrectly triggers the `AssertionError: Trying to update to time in the past` in the memory_updater module. This is caused by lines 185-186 in tgn.py:

```python
def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
    ...
    if self.use_memory:
      if self.memory_update_at_start:
        # Update memory for all nodes with messages stored in previous batches
        memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                      self.memory.messages)
      else:
        memory = self.memory.get_memory(list(range(self.n_nodes)))
        last_update = self.memory.last_update

      ...

    if self.use_memory:
      if self.memory_update_at_start:
        # Persist the updates to the memory only for sources and destinations (since now we have
        # new messages for them)
        self.update_memory(positives, self.memory.messages)

        assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
          "Something wrong in how the memory was updated"

        # Remove messages for the positives since we have already updated the memory using them
        self.memory.clear_messages(positives)

      unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                    source_node_embedding,
                                                                    destination_nodes,
                                                                    destination_node_embedding,
                                                                    edge_times, edge_idxs)
      unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                              destination_node_embedding,
                                                                              source_nodes,
                                                                              source_node_embedding,
                                                                              edge_times, edge_idxs)
      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)            # <-- line 185
        self.update_memory(unique_destinations, destination_id_to_messages)  # <-- line 186

      ...

    return source_node_embedding, destination_node_embedding, negative_node_embedding
```

When source_nodes and destination_nodes contain non-overlapping node ids, this is not a problem. However, on a unipartite graph the same node id can appear in both source_nodes and destination_nodes, which triggers the described error whenever that node id is associated with a later timestamp on the source side than on the destination side.
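
To make the failure concrete, here is a minimal sketch of the monotonicity check that fails. This is toy code, not the actual TGN classes: the array `last_update` and the simplified `update_memory` below stand in for the real memory and memory_updater modules, which work on message dictionaries rather than raw timestamp arrays.

```python
import numpy as np

# Toy stand-in for the memory's last_update bookkeeping. The real check lives in the
# memory_updater module and raises "Trying to update to time in the past".
last_update = np.zeros(5)  # one entry per node id 0..4

def update_memory(node_ids, timestamps):
  node_ids = np.asarray(node_ids)
  timestamps = np.asarray(timestamps, dtype=float)
  # Memory timestamps must never move backwards.
  assert (last_update[node_ids] <= timestamps).all(), \
    "Trying to update to time in the past"
  last_update[node_ids] = timestamps

# A unipartite batch with two interactions: 3 -> 2 at t=7, then 2 -> 4 at t=10.
# Node 2 appears as a destination at t=7 and as a source at the later time t=10.
source_nodes,      source_times      = [3, 2], [7., 10.]
destination_nodes, destination_times = [2, 4], [7., 10.]

update_memory(source_nodes, source_times)  # node 2's memory time becomes 10
try:
  update_memory(destination_nodes, destination_times)  # tries to move node 2 back to t=7
except AssertionError as err:
  print(err)  # -> Trying to update to time in the past
```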

This problem can be resolved by replacing:

```python
      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)
        self.update_memory(unique_destinations, destination_id_to_messages)
```

with:

```python
      self.memory.store_raw_messages(unique_sources, source_id_to_messages)
      self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)

      if not self.memory_update_at_start:
        unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations)))
        self.update_memory(unique_node_ids, self.memory.messages)
        self.memory.clear_messages(unique_node_ids)
```
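
With this change, every node gets at most one memory update per batch, computed from all of its stored messages at once, so the memory updater only ever sees that node's latest timestamp. Continuing the toy sketch from above (again an illustration, not the real code path, where update_memory aggregates the stored messages per node before updating):

```python
# Mirror the fix: deduplicate the node ids and update each node once, at the
# latest timestamp among all of the messages stored for it in this batch.
last_update[:] = 0  # reset the toy memory

messages = {}  # node id -> timestamps of the messages stored for it this batch
for node, t in zip(source_nodes + destination_nodes, source_times + destination_times):
  messages.setdefault(node, []).append(t)

unique_node_ids = np.unique(np.concatenate((source_nodes, destination_nodes)))
latest_times = [max(messages[n]) for n in unique_node_ids]

update_memory(unique_node_ids, latest_times)  # node 2 is updated once, to t=10: no error
```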

Edit: I found an issue in the fix I initially proposed and have updated it to match the pull request.

daniel-gomm (Author) commented:

I have opened pull request #30 with the changes detailed above.
