
TypeError: odeint() takes 3 positional arguments but 6 were given #2

@mattrmd


I'm opening this issue and including a solution in case anyone else wants to run this code.

This repo doesn't have a requirements.txt, so I installed the latest versions of all packages. This mostly worked, but the API for torchdiffeq has significantly changed since the ExNODE repo was created. When I ran the code, I got the following error:

```
Traceback (most recent call last):
  File "/home/matt/code/school/ExNODE/classification/train.py", line 119, in <module>
    main()
  File "/home/matt/code/school/ExNODE/classification/train.py", line 111, in main
    metric_log = model.fit() # training and return the log
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 282, in fit
    train_acc = self.train()
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 314, in train
    logits = self.logits(x)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 254, in logits
    logits_x = self.model(x)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 231, in forward
    x = self.model(x)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/matt/anaconda3/envs/exnode/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1186, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/matt/code/school/ExNODE/classification/classifier.py", line 135, in forward
    out = odeint_normal(self.odefunc, x, self.integration_time, self.rtol, self.atol, self.solver)
TypeError: odeint() takes 3 positional arguments but 6 were given
```
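The message is Python's standard `TypeError` for a function that accepts only three positional parameters but receives six. A minimal reproduction under stated assumptions: `odeint_stub` below is a hypothetical stand-in whose signature mirrors the form the error implies (three positional parameters, tolerances and method keyword-only); it does not import or call torchdiffeq. It shows the old positional call failing, the keyword call succeeding, and an `inspect.signature` check that lists which parameters must be passed by keyword.

```python
import inspect


def odeint_stub(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Hypothetical stand-in: 3 positional parameters, everything after `*` is keyword-only."""
    return {"rtol": rtol, "atol": atol, "method": method}


def try_positional_call():
    """Call the stub the old way (6 positional arguments) and return the error text, if any."""
    try:
        odeint_stub(None, 0.0, [0.0, 1.0], 1e-3, 1e-3, "dopri5")
    except TypeError as exc:
        return str(exc)
    return None


# Passing the extra arguments by name is accepted.
keyword_result = odeint_stub(None, 0.0, [0.0, 1.0], rtol=1e-3, atol=1e-3, method="dopri5")

# List the parameters that must be passed by keyword.
keyword_only = [
    name
    for name, param in inspect.signature(odeint_stub).parameters.items()
    if param.kind is inspect.Parameter.KEYWORD_ONLY
]

print(try_positional_call())
print(keyword_result)
print(keyword_only)
```

The same `inspect.signature` check could be pointed at the installed `torchdiffeq.odeint` or `odeint_adjoint` to see whether a given version expects `rtol`, `atol`, and `method` by keyword.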

Note that classification/classifier.py imports `from torchdiffeq import odeint_adjoint as odeint`. I did some digging on the torchdiffeq repo, and it looks like this interface (6 positional arguments) for odeint_adjoint was last available in 2020. I found code in this commit that has the original 6-argument interface. In the current interface, `odeint` and `odeint_adjoint` accept only `func`, `y0`, and `t` positionally, so `rtol`, `atol`, and `method` must be passed as keyword arguments. To adapt ExNODE to the new interface, replace the code for the `ODEBlock.forward` function with the following:

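```python
def forward(self, x):
    # Move the integration time points to the same device and dtype as the input.
    self.integration_time = self.integration_time.to(x)
    if self.solver != 'dopri5':
        # Solvers other than dopri5 go through odeint_normal.
        # rtol, atol, and method are passed by keyword, as the new API requires.
        out = odeint_normal(
            func=self.odefunc,
            y0=x,
            t=self.integration_time,
            rtol=self.rtol,
            atol=self.atol,
            method=self.solver,
        )
    else:
        # dopri5 goes through odeint, which classifier.py imports as odeint_adjoint.
        out = odeint(
            func=self.odefunc,
            y0=x,
            t=self.integration_time,
            rtol=self.rtol,
            atol=self.atol,
            method=self.solver,
        )
    # The solver returns the state at every time in t; keep only the final one.
    return out[-1]
```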

As far as I can tell, this works, and I get results comparable to those reported in the paper:
14:03:56 Epoch: 119 Train: 0.9486/0.9506 Test: 0.8688/0.8932 for deepset.