Pytorch triple backward#200
Conversation
|
I'm going to run the full test suite tomorrow. |
vbharadwaj-bk
left a comment
There was a problem hiding this comment.
95% looks good. I think the stream_test.py modifications are redundant, since each of the custom ops have already been tested. But otherwise good.
| ): | ||
| assert self.torch_op | ||
|
|
||
| in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) |
There was a problem hiding this comment.
TODO-someday: I wonder if we can combine all of these derivative functions into one to compact this file.
There was a problem hiding this comment.
Do you just mean, turn them from .to("cuda") to (... , dtype=cuda).
If so, I should definitely do this. It's considered good practice to. technically this does something funny which is create the tensor, then move it to the device, so the autograd actually does the opposite, moves it off device at the end. When it's supposed to just be on device I ought to use dtype=cuda. I'll clean this up. good point.
| @pytest.fixture(scope="class") | ||
| def problem(self, dtype, with_jax): | ||
| if with_jax: | ||
| pytest.skip("N/A for JAX") |
There was a problem hiding this comment.
TODO-someday: we could expand this test to include JAX. But not in this commit.
| return (X, Y, W, edge_index[0], edge_index[1]) | ||
|
|
||
|
|
||
| @pytest.fixture |
There was a problem hiding this comment.
Hmm. I don't think we should have any modifications to stream_test.py. Because triple_backward is a composition of existing ops that all work fine with streams, I see no reason why their composition shouldn't pass stream tests. We need to test anything that's implemented as a custom op to make sure that the stream information is lowered correctly onto the kernel, but then any composition of those operators should be ok. Let's shrink the diff here.
There was a problem hiding this comment.
The reason that any change was made is because the stream tests say (in the name) they are testing the deterministic mode but they are not (deterministic=false). This was masking some strange behavior.
I can move this to a different PR or just ignore it. It is off topic for this PR. I'll remove it.
There was a problem hiding this comment.
moved to a different pr
| self.check_result(result, fieldname) | ||
|
|
||
|
|
||
| class TestTripleBackwardConvDirectOps: |
There was a problem hiding this comment.
Remind me, what is the purpose of these DirectOps tests?
There was a problem hiding this comment.
I removed them, they were a vestigial test from the first pass at tests, before I moved them to be part of the parameterized tests
vbharadwaj-bk
left a comment
There was a problem hiding this comment.
nice, merge when ci finishes 🏁
|
Please don't merge yet, I saw some interesting failures in the test suite that I'd like to investigate. I think they are not related to this PR, but I'd like to confirm before proceeding. |
Adding triple backward support for higher order training to pytorch