I'm opening this issue and including a solution in case anyone else wants to run this code.
This repo doesn't have a `requirements.txt`, so I installed the latest versions of all packages. This mostly worked, but the API for `torchdiffeq` has changed significantly since the `ExNODE` repo was created. When I ran the code, I got the following error:
```
Traceback (most recent call last):
  File "/home/matt/code/school/ExNODE/classification/train.py", line 119, in <module>
    main()
  File "/home/matt/code/school/ExNODE/classification/train.py", line 111, in main
    metric_log = model.fit() # training and return the log
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 282, in fit
    train_acc = self.train()
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 314, in train
    logits = self.logits(x)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 254, in logits
    logits_x = self.model(x)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 231, in forward
    x = self.model(x)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 135, in forward
    out = odeint_normal(self.odefunc, x, self.integration_time, self.rtol, self.atol, self.solver)
TypeError: odeint() takes 3 positional arguments but 6 were given
```
Note that the imports for `classification/classifier.py` include `from torchdiffeq import odeint_adjoint as odeint`. I did some digging in the `torchdiffeq` repo, and it looks like this interface (6 positional arguments) for `odeint_adjoint` was last available in 2020. I found code in this commit that has the original 6-argument interface. To adapt `ExNODE` to the new interface, replace the code for the `ODEBlock.forward` function with the following:
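```python
def forward(self, x):
    # Keep the integration times on the same device and dtype as the input.
    self.integration_time = self.integration_time.to(x)
    # Newer torchdiffeq releases accept only (func, y0, t) positionally,
    # so rtol, atol and method must be passed as keyword arguments.
    if self.solver != 'dopri5':
        out = odeint_normal(
            func=self.odefunc,
            y0=x,
            t=self.integration_time,
            rtol=self.rtol,
            atol=self.atol,
            method=self.solver,
        )
    else:
        # `odeint` here is torchdiffeq's odeint_adjoint (see the imports).
        out = odeint(
            func=self.odefunc,
            y0=x,
            t=self.integration_time,
            rtol=self.rtol,
            atol=self.atol,
            method=self.solver,
        )
    # The solver returns the state at every time in t; keep the final one.
    return out[-1]
```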
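To see why switching to keyword arguments fixes the `TypeError`, here is a minimal pure-Python sketch. `odeint_old` and `odeint_new` are hypothetical stand-ins that only mimic the parameter lists (the names and defaults are my assumptions, not copied from `torchdiffeq`): the old style lets every parameter be passed positionally, while the new style makes everything after `t` keyword-only.

```python
# Hypothetical stand-ins that only mimic the *signatures* of an old-style and
# a new-style solver; they do not integrate anything, they just echo `method`.

def odeint_old(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Old style: every parameter may be passed positionally."""
    return method


def odeint_new(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None):
    """New style: everything after `t` is keyword-only."""
    return method


def call_positionally(solver):
    # Mirrors the original ExNODE call: six positional arguments.
    return solver(None, 0.0, [0.0, 1.0], 1e-3, 1e-3, "dopri5")


def call_with_keywords(solver):
    # Mirrors the fixed ODEBlock.forward: every argument passed by name.
    return solver(func=None, y0=0.0, t=[0.0, 1.0], rtol=1e-3, atol=1e-3, method="dopri5")


if __name__ == "__main__":
    print(call_positionally(odeint_old))   # dopri5
    print(call_with_keywords(odeint_old))  # dopri5
    print(call_with_keywords(odeint_new))  # dopri5
    try:
        call_positionally(odeint_new)
    except TypeError as err:
        print(err)  # odeint_new() takes 3 positional arguments but 6 were given
```

Assuming the older release used the same parameter names, the keyword-argument version should work against both interfaces, while the positional call only works against the old one.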
As far as I can tell, this works, and I get results comparable to the paper:

`14:03:56 Epoch: 119 Train: 0.9486/0.9506 Test: 0.8688/0.8932` for deepset.