The following snippet highlights the issue:
int main() {
    {
        auto a = std::make_shared<Value>(1);
        auto b = std::make_shared<Value>(2);
        auto c = a + b;
    }
    // a, b and c are not destroyed ...
}
The issue seems to be with the variables captured by the _backward lambda:
std::shared_ptr<Value> Value::operator+(const std::shared_ptr<Value>& other) {
    auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};
    auto out = std::make_shared<Value>(data + other->data, out_prev, "+");
    out->_backward = [this, other, out] { // We're capturing out, hence the lambda shares ownership of out!
        grad += out->grad;
        other->grad += out->grad;
    };
    return out;
}
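This creates a reference cycle: out owns the _backward lambda, and the lambda owns a copy of out (a shared_ptr). So _backward is not destroyed until out is destroyed, and out is not destroyed until _backward is destroyed, hence neither is ever destroyed. Because out also keeps its operands alive through out_prev, a and b leak along with it.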
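To see the cycle in isolation, here is a minimal, self-contained sketch. The Value struct is a hypothetical stand-in trimmed to the members the snippets use, plus an assumed _prev member where the constructor stores out_prev. The destroyed counter, add_leaky (the original operator+ body as a named member function) and leaky_demo are illustrative names, not part of the original code. The sketch reproduces the leak, then shows that clearing _backward by hand breaks the cycle and frees all three objects.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

// Hypothetical stand-in for Value, trimmed to the members the snippets use.
// _prev is assumed to be where the constructor stores out_prev.
struct Value : std::enable_shared_from_this<Value> {
    static inline int destroyed = 0;  // counts destructor calls (illustration only)

    double data;
    double grad = 0.0;
    std::unordered_set<std::shared_ptr<Value>> _prev;
    std::string _op;
    std::function<void()> _backward = [] {};

    explicit Value(double d, std::unordered_set<std::shared_ptr<Value>> prev = {},
                   std::string op = "")
        : data(d), _prev(std::move(prev)), _op(std::move(op)) {}
    ~Value() { ++destroyed; }

    // Same body as the original operator+, written as a named member function.
    std::shared_ptr<Value> add_leaky(const std::shared_ptr<Value>& other) {
        auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};
        auto out = std::make_shared<Value>(data + other->data, out_prev, "+");
        // Capturing out by value stores a shared_ptr to out inside out itself.
        out->_backward = [this, other, out] {
            grad += out->grad;
            other->grad += out->grad;
        };
        return out;
    }
};

// Returns {destroyed after leaving scope, destroyed after breaking the cycle by hand}.
std::pair<int, int> leaky_demo() {
    Value::destroyed = 0;
    std::weak_ptr<Value> watch_c;  // observes c without keeping it alive
    {
        auto a = std::make_shared<Value>(1);
        auto b = std::make_shared<Value>(2);
        auto c = a->add_leaky(b);
        watch_c = c;
    }
    int after_scope = Value::destroyed;  // 0: c owns itself, and c owns a and b
    if (auto c = watch_c.lock()) {
        c->_backward = nullptr;  // destroying the lambda releases c's self-reference
    }  // the last owner of c goes away here, which in turn releases a and b
    return {after_scope, Value::destroyed};
}
```

Calling leaky_demo() should yield {0, 3}: nothing is freed when the scope ends, and all three objects are freed once the self-referencing lambda is dropped.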
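To fix this, the lambda should not own out; a non-owning (weak) reference is enough. A raw pointer works here because the lambda is stored inside out, so out is alive whenever _backward is called through it: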
std::shared_ptr<Value> Value::operator+(const std::shared_ptr<Value>& other) {
    auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};
    auto out = std::make_shared<Value>(data + other->data, out_prev, "+");
    // Non-owning pointer to out: capturing it does not increase out's reference
    // count, so the lambda no longer keeps out alive and the cycle is gone.
    Value* weak_ref = out.get();
    out->_backward = [this, other, weak_ref] {
        grad += weak_ref->grad;
        other->grad += weak_ref->grad;
    };
    return out;
}
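For comparison, here is a sketch of the same fix using std::weak_ptr, which is what "weak ownership" literally means in the standard library. It reuses the same hypothetical, trimmed-down Value (the _prev member, the destroyed counter, add and run_demo are illustrative assumptions, not the original code). The lambda locks the weak pointer each time it runs; since out is alive whenever _backward is called through it, the lock succeeds in normal use.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

// Hypothetical stand-in for Value, trimmed to what the example needs.
struct Value : std::enable_shared_from_this<Value> {
    static inline int destroyed = 0;  // counts destructor calls (illustration only)

    double data;
    double grad = 0.0;
    std::unordered_set<std::shared_ptr<Value>> _prev;
    std::string _op;
    std::function<void()> _backward = [] {};

    explicit Value(double d, std::unordered_set<std::shared_ptr<Value>> prev = {},
                   std::string op = "")
        : data(d), _prev(std::move(prev)), _op(std::move(op)) {}
    ~Value() { ++destroyed; }

    // Variant of the fix that captures a std::weak_ptr instead of a raw pointer.
    std::shared_ptr<Value> add(const std::shared_ptr<Value>& other) {
        auto out_prev = std::unordered_set<std::shared_ptr<Value>>{shared_from_this(), other};
        auto out = std::make_shared<Value>(data + other->data, out_prev, "+");
        std::weak_ptr<Value> weak_out = out;  // does not add to out's use_count
        out->_backward = [this, other, weak_out] {
            if (auto o = weak_out.lock()) {  // becomes a no-op if out is already gone
                grad += o->grad;
                other->grad += o->grad;
            }
        };
        return out;
    }
};

struct DemoResult {
    double grad_a;
    double grad_b;
    int destroyed;
};

DemoResult run_demo() {
    Value::destroyed = 0;
    DemoResult r{};
    {
        auto a = std::make_shared<Value>(1);
        auto b = std::make_shared<Value>(2);
        auto c = a->add(b);
        c->grad = 1.0;
        c->_backward();  // propagates c's gradient to both operands
        r.grad_a = a->grad;
        r.grad_b = b->grad;
    }
    r.destroyed = Value::destroyed;  // c, b and a are all freed at the end of the scope
    return r;
}
```

run_demo() should report a gradient of 1.0 on both operands and 3 destroyed objects. As a design trade-off, the raw pointer is cheaper and is safe as long as _backward is only invoked through out; the weak_ptr version pays for a lock on every call, but turns a dangling access into a no-op if a copy of _backward ever outlives out.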