Add get_training_input to storage_utils #438
base: main
Conversation
mkolodner-sc left a comment
Thanks Kyle! Did a pass here and left some comments/questions
def get_training_input(
    split: Union[Literal["train", "val", "test"], str],
Is this sufficient? What if we want "all" nodes (i.e. dataset.node_ids)?
That wouldn't be "training input", would it? We can add another function to do that in the future (get_all_nodes?)
It could be for the random negative loader for link prediction training, which would need some dataset.node_ids or equivalent.
> We can add another function to do that in the future (get_all_nodes?)

We wouldn't need a whole different function; we could just specify some 'all' split and, if it's that, use _dataset.node_ids.
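A minimal sketch of how that routing could look, assuming a dataset object exposing node_ids and per-split id tensors (all names here are hypothetical stand-ins, not the PR's actual types):

```python
from dataclasses import dataclass
from typing import Dict, Literal, Union

import torch


# Hypothetical stand-in for the real dataset; attribute names are assumptions.
@dataclass
class _ToyDataset:
    node_ids: torch.Tensor  # every node id held by this storage shard
    split_node_ids: Dict[str, torch.Tensor]  # per-split id tensors


def resolve_split_ids(
    dataset: _ToyDataset,
    split: Union[Literal["train", "val", "test", "all"], str],
) -> torch.Tensor:
    if split == "all":
        # e.g. for the random negative loader in link prediction training.
        return dataset.node_ids
    return dataset.split_node_ids[split]


ds = _ToyDataset(
    node_ids=torch.arange(10),
    split_node_ids={
        "train": torch.arange(6),
        "val": torch.arange(6, 8),
        "test": torch.arange(8, 10),
    },
)
assert torch.equal(resolve_split_ids(ds, "all"), torch.arange(10))
```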
Hmmmm, I do think that the tuple[Tensor, Tensor, Tensor | None] return type is important for ABLP.
I guess I could rename this to get_ablp_input or something? Would that ameliorate your concerns?
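For reference, a sketch of what the renamed signature might look like; the return type mirrors the tuple discussed above (anchors, supervision nodes, optional negatives), and everything here is an assumption rather than the PR's actual code:

```python
from typing import Literal, Optional, Tuple, Union

import torch


def get_ablp_input(
    split: Union[Literal["train", "val", "test"], str],
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Sketch only: (anchor node ids, supervision node ids, optional negatives).
    raise NotImplementedError
```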
Discussed offline: renaming will be fine here for this function. In a follow-up we will refactor the get_node_ids_on_rank utility so that it can be used to split, making it extendable for the SNC use case, and so it can be called in this function to reduce duplication. Can we add a TODO here in the meantime?
Sure, added TODOs :)
| f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}" | ||
| ) | ||
|
|
||
| anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size) |
I remember we had a discussion at some point of potentially moving this up to be user-facing instead of having it done under-the-hood by a GiGL utility/class. Would be curious to get your thoughts on this decision here as well.
hmmmm, we did and I think we agreed we should do less of this - though we still do it in GS mode.
I went with this approach in GS mode so we minimize network traffic.
Since this is how we are currently doing it for RemoteDistDataset and friends, should we keep this approach for now and revisit in the future, for all node fetching?
Adding some get_for_rank bool flag is probably sufficient here? WDYT about that?
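A rough sketch of what that flag could look like; shard_nodes_by_process here is a simplified strided version for illustration, not necessarily the real GiGL utility:

```python
import torch


def shard_nodes_by_process(
    node_ids: torch.Tensor, rank: int, world_size: int
) -> torch.Tensor:
    # Simplified strided sharding; the real utility's scheme may differ.
    return node_ids[rank::world_size]


def fetch_anchor_nodes(
    anchor_nodes: torch.Tensor,
    rank: int,
    world_size: int,
    get_for_rank: bool = True,  # hypothetical opt-out of implicit sharding
) -> torch.Tensor:
    if get_for_rank:
        return shard_nodes_by_process(anchor_nodes, rank, world_size)
    return anchor_nodes
```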
I'm fine either way, whether you'd prefer to do this now or in a follow-up. If in a follow-up, let's just make sure we add a TODO here.
I had the same qq.
My understanding was that there was some agreement that the client should own the complicated logic of knowing what data to fetch from where.
Whether the user prompts the client code to fetch exactly what they want, or uses this sort of utility on the client to implicitly fetch some split of nodes, the logic lives inside the client and the server is dumb and just fetches the data requested.
I see the storage layer as a sort of "db" specific to graph sampling.
Complex code on the server deciding what data to provide to which client brings forward some smells; specifically, I think it:
- Breaks determinism (the operations are implicit and outside the client's control)
- Makes queries non-portable, i.e. we are stuck defining queries inside the server instead of keeping them client-controlled
- Couples the query patterns to the runtime topology, which may change in the future; the coupling will induce extra eng effort to circumvent later
- Potentially breaks the ability to do replica/retry strategies (if needed in the future), since the client doesn't know what data to expect, and alignment between the two will be difficult unless the client hosts the logic.
If we do decide to go this route, I think we should prove out that it is for some sane reason, i.e. actually prove out the network-traffic argument. Would love elaboration here.
mkolodner-sc left a comment
Thanks Kyle for the work!
    )


def destroy_test_process_group() -> None:
nit: lazy abstraction to a function, especially since it's only used in one place.
The split ratios are calculated as:
- num_val = len(val_user_ids) / total_users
- num_test = len(test_user_ids) / total_users
I am a little confused reading this; isn't the user already passing in val_user_ids? Why do we need to recompute the val split if one is being provided?
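To make the confusion concrete, here is the docstring's arithmetic with toy values (the ids below are made up):

```python
import torch

val_user_ids = torch.tensor([3, 4])
test_user_ids = torch.tensor([5])
total_users = 10

num_val = len(val_user_ids) / total_users  # 0.2, a ratio despite the name
num_test = len(test_user_ids) / total_users  # 0.1
```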
Creates a dataset with:
- USER nodes: [0, 1, 2, 3, 4]
- STORY nodes: [0, 1, 2, 3, 4]
mega nit: STORY --> ITEM. It's a more generic term for OSS.
# Set up edge partition books and edge indices
edge_partition_book = {
    _USER_TO_STORY: torch.zeros(5, dtype=torch.int64),
Magic number 5 - is this the number of edges?
Similar magic numbers below re the edge index.
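One possible cleanup, replacing the magic number with named constants; the constant names and the edge-type key are assumptions for illustration:

```python
import torch

_USER_TO_ITEM = ("user", "to", "item")  # hypothetical edge-type key
_NUM_USER_TO_ITEM_EDGES = 5  # length of the test edge index

edge_partition_book = {
    # One partition-book entry per edge, all placed on rank 0.
    _USER_TO_ITEM: torch.zeros(_NUM_USER_TO_ITEM_EDGES, dtype=torch.int64),
}
```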
Creates a dataset with:
- USER nodes: [0, 1, 2, 3, 4]
- STORY nodes: [0, 1, 2, 3, 4]
Should we make these arguments, i.e. user_node_ids, item_node_ids?
Or at least default arguments/constants that can be referenced below to make the function more modular? A sketch follows.
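A sketch of the suggestion, with the hard-coded ids pulled into defaults; the function and parameter names here are hypothetical:

```python
from typing import Optional

import torch

_DEFAULT_USER_NODE_IDS = torch.arange(5)  # [0, 1, 2, 3, 4]
_DEFAULT_ITEM_NODE_IDS = torch.arange(5)  # [0, 1, 2, 3, 4]


def build_test_dataset(
    user_node_ids: Optional[torch.Tensor] = None,
    item_node_ids: Optional[torch.Tensor] = None,
) -> dict:
    # Fall back to the module-level constants the rest of the test can reference.
    user_node_ids = _DEFAULT_USER_NODE_IDS if user_node_ids is None else user_node_ids
    item_node_ids = _DEFAULT_ITEM_NODE_IDS if item_node_ids is None else item_node_ids
    return {"user": user_node_ids, "item": item_node_ids}
```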
| f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}" | ||
| ) | ||
|
|
||
| anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had the same qq.
My understanding was that there was some agreement where the client should have complicated logic of knowing what data to fetch from where.
Now whether the user prompts the client code to fetch exactly what they want, or uses this sort of utility on the client to implicitly fetch some split of nodes, the logic lives inside the client and the server is dumb and just fetches the data requested.
| f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}" | ||
| ) | ||
|
|
||
| anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see storage layer as a sort of a "db" specific to graph sampling.
If complex code to decide what data to provide to what client brings forward some smells, speicifcally I think it:
- Breaks determinism (The operations are implicit, and outside of control from client)
- Makes queries non-portable i.e. we are stuck with defining queries inside of the server vs client controlled
- Couples the query patterns to runtime topology; which may change in the future and the coupling will induce extra eng effort int the future to circumvent
- Potentially break ability to do replica/retry strategies (if needed in future) as the client doesn't know what data to expect and alignment between the two will be difficult unless client hosts logic.
If we do decide to go for this route I think we should prove out that it is for some sane reason i.e. actually proving out the network traffic argument. Would love elaboration here.
Scope of work done
Add server-side util so we can remotely fetch the training input.
Since this is kind of a big PR, not adding the client-side equivalent in this one :P
Again, this is server-side code, and it's not really meant to be called by users directly.
Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md?: NO
Ready for code review?: NO