Pytorch triple backward #200

Open
asglover wants to merge 10 commits into main from pytorch-triple-backward

Collaborator

@asglover asglover commented Jun 2, 2026

Adding triple backward support to PyTorch for higher-order training.
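
For readers unfamiliar with the term, here is a minimal sketch of what differentiating three times through autograd looks like in PyTorch. It assumes "triple backward" means differentiating through the double-backward pass. The function `f` is a hypothetical stand-in, not one of this PR's ops.

```python
import torch


def f(x):
    # Hypothetical smooth function standing in for the real op.
    return (x ** 4).sum()


x = torch.tensor([1.0, 2.0], dtype=torch.float64, requires_grad=True)

# create_graph=True records each backward pass so it can be differentiated again.
(g1,) = torch.autograd.grad(f(x), x, create_graph=True)      # 4 * x**3
(g2,) = torch.autograd.grad(g1.sum(), x, create_graph=True)  # 12 * x**2
# The last call does not need create_graph because nothing differentiates it further.
(g3,) = torch.autograd.grad(g2.sum(), x)                     # 24 * x
```

A custom op only supports the third call if its double-backward is itself built from differentiable operations, which is presumably what this PR provides.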

@asglover asglover added the ci-ready Triggers CI checks for a pull request label Jun 2, 2026
Collaborator Author

asglover commented Jun 2, 2026

I'm going to run the full test suite tomorrow.
The changes are mostly about triple backward, although my model was convinced that there was an error in the stream testing, and on closer inspection it looks like there was.
I'll promote this from Draft to a regular PR when it's ready for review.

Member

@vbharadwaj-bk vbharadwaj-bk left a comment

95% looks good. I think the stream_test.py modifications are redundant, since each of the custom ops has already been tested. But otherwise good.

):
assert self.torch_op

in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
Member

TODO-someday: I wonder if we can combine all of these derivative functions into one to compact this file.

Collaborator Author

Do you just mean changing them from .to("cuda") to torch.tensor(..., device="cuda")?

If so, I should definitely do this; it's considered good practice. Technically the current code does something funny: it creates the tensor on the host and then moves it to the device, so whenever that move is tracked by autograd, the backward pass does the opposite and moves the gradient off device at the end. When the tensor is supposed to just live on the device, I ought to pass device="cuda" at construction. I'll clean this up. Good point.
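
A minimal sketch of the two construction styles discussed above. Note that the keyword that places a tensor on a device is `device`, while `dtype` selects the element type. The `data` list is a hypothetical stand-in for `in1`, and the CPU fallback is only there so the snippet runs without a GPU.

```python
import torch

# Fall back to CPU so the sketch also runs on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
data = [[1.0, 2.0], [3.0, 4.0]]  # hypothetical stand-in for in1

# Two-step form: build on the host, copy to the device, then mark as requiring grad.
a = torch.tensor(data).to(device).requires_grad_(True)

# One-step form: allocate directly on the target device as a leaf tensor.
b = torch.tensor(data, device=device, requires_grad=True)
```

Both forms produce a leaf tensor on the target device here. The one-step form skips the intermediate host tensor and leaves no room for the copy to end up inside the autograd graph.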

Collaborator Author

fixed

Comment thread tests/batch_test.py Outdated
@pytest.fixture(scope="class")
def problem(self, dtype, with_jax):
if with_jax:
pytest.skip("N/A for JAX")
Member

TODO-someday: we could expand this test to include JAX. But not in this commit.

Comment thread tests/stream_test.py Outdated
return (X, Y, W, edge_index[0], edge_index[1])


@pytest.fixture
Member

Hmm. I don't think we should have any modifications to stream_test.py. Because triple_backward is a composition of existing ops that all work fine with streams, I see no reason why their composition shouldn't pass stream tests. We need to test anything that's implemented as a custom op to make sure that the stream information is lowered correctly onto the kernel, but then any composition of those operators should be ok. Let's shrink the diff here.

Collaborator Author

@asglover asglover Jun 5, 2026

The reason any change was made is that the stream tests say (in the name) that they test deterministic mode, but they do not (deterministic=false). This was masking some strange behavior.

I can move this to a different PR or just ignore it. It is off topic for this PR, so I'll remove it.
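
One way to keep a test's name and its flag from drifting apart, sketched with pytest parametrization. This is only an illustration: `make_op_config` and `test_stream_mode` are hypothetical names, not code from stream_test.py.

```python
import pytest


def make_op_config(deterministic: bool) -> dict:
    """Hypothetical stand-in for constructing the op under test."""
    return {"deterministic": deterministic}


@pytest.mark.parametrize(
    "deterministic",
    [True, False],
    # The test id is generated from the flag itself, so the reported name
    # cannot claim a mode that the test body does not actually use.
    ids=lambda flag: "deterministic" if flag else "nondeterministic",
)
def test_stream_mode(deterministic):
    config = make_op_config(deterministic)
    assert config["deterministic"] is deterministic
```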

Collaborator Author

moved to a different pr

Comment thread tests/conv_test.py Outdated
self.check_result(result, fieldname)


class TestTripleBackwardConvDirectOps:
Member

Remind me, what is the purpose of these DirectOps tests?

Collaborator Author

I removed them. They were vestigial tests from my first pass, before I moved them into the parameterized tests.

@asglover asglover removed the ci-ready Triggers CI checks for a pull request label Jun 5, 2026
@vbharadwaj-bk vbharadwaj-bk added the ci-ready Triggers CI checks for a pull request label Jun 5, 2026
Member

@vbharadwaj-bk vbharadwaj-bk left a comment

nice, merge when ci finishes 🏁

@vbharadwaj-bk vbharadwaj-bk marked this pull request as ready for review June 5, 2026 05:29
@vbharadwaj-bk vbharadwaj-bk added ci-ready Triggers CI checks for a pull request and removed ci-ready Triggers CI checks for a pull request labels Jun 5, 2026
Collaborator Author

asglover commented Jun 5, 2026

Please don't merge yet, I saw some interesting failures in the test suite that I'd like to investigate. I think they are not related to this PR, but I'd like to confirm before proceeding.

Labels

ci-ready Triggers CI checks for a pull request

2 participants