Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions configs/postprocessors/ras.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
postprocessor:
name: ras
APS_mode: False
postprocessor_args: {}
postprocessor_sweep: {}
6 changes: 5 additions & 1 deletion openood/evaluation_api/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openood.networks.ash_net import ASHNet
from openood.networks.react_net import ReactNet
from openood.networks.scale_net import ScaleNet
from openood.networks.ras_net import RASNet
from openood.networks.adascale_net import AdaScaleANet, AdaScaleLNet

from .datasets import DATA_INFO, data_setup, get_id_ood_dataloader
Expand Down Expand Up @@ -90,7 +91,8 @@ def __init__(
# set up config root
if config_root is None:
filepath = os.path.dirname(os.path.abspath(__file__))
config_root = os.path.join('/', *filepath.split('/')[:-2], 'configs')
config_root = os.path.join('/',
*filepath.split('/')[:-2], 'configs')

# get postprocessor
if postprocessor is None:
Expand Down Expand Up @@ -121,6 +123,8 @@ def __init__(
net = AdaScaleANet(net)
elif postprocessor_name == 'adascale_l':
net = AdaScaleLNet(net)
elif postprocessor_name == 'ras':
net = RASNet(net)

# postprocessor setup
postprocessor.setup(net, dataloader_dict['id'], dataloader_dict['ood'])
Expand Down
5 changes: 3 additions & 2 deletions openood/evaluation_api/postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
RMDSPostprocessor, SHEPostprocessor, CIDERPostprocessor, NPOSPostprocessor,
GENPostprocessor, NNGuidePostprocessor, RelationPostprocessor,
T2FNormPostprocessor, ReweightOODPostprocessor, fDBDPostprocessor,
AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor,CFOODPostprocessor,
VRAPostprocessor, GrOODPostprocessor)
AdaScalePostprocessor, IODINPostprocessor, NCIPostprocessor,
CFOODPostprocessor, VRAPostprocessor, GrOODPostprocessor, RASPostprocessor)
from openood.utils.config import Config, merge_configs

postprocessors = {
Expand Down Expand Up @@ -73,6 +73,7 @@
'grood': GrOODPostprocessor,
'vra': VRAPostprocessor,
'cfood': CFOODPostprocessor,
'ras': RASPostprocessor,
}

link_prefix = 'https://raw.githubusercontent.com/Jingkang50/OpenOOD/main/configs/postprocessors/'
Expand Down
33 changes: 33 additions & 0 deletions openood/networks/ras_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn


class RASNet(nn.Module):
def __init__(self, backbone):
super(RASNet, self).__init__()
self.backbone = backbone
self.mean_curve = None

def forward(self, x, return_feature=False, return_feature_list=False):
try:
return self.backbone(x, return_feature, return_feature_list)
except TypeError:
return self.backbone(x, return_feature)

def forward_shift(self, x):
_, feature = self.backbone(x, return_feature=True)
feature = feature.view(feature.size(0), -1)

sorted_vals, idx = torch.sort(feature, dim=1)
mc = self.mean_curve.to(feature.device).expand_as(sorted_vals)
shifted = torch.empty_like(feature).scatter_(1, idx, mc)

logits_cls = self.backbone.get_fc_layer()(shifted)
return logits_cls

def set_mean_curve(self, curve):
self.mean_curve = curve

def get_fc(self):
fc = self.backbone.fc
return fc.weight.cpu().detach().numpy(), fc.bias.cpu().detach().numpy()
5 changes: 5 additions & 0 deletions openood/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .openmax_net import OpenMax
from .patchcore_net import PatchcoreNet
from .projection_net import ProjectionNet
from .ras_net import RASNet
from .react_net import ReactNet
from .resnet18_32x32 import ResNet18_32x32
from .resnet18_64x64 import ResNet18_64x64
Expand Down Expand Up @@ -172,6 +173,10 @@ def get_network(network_config):
backbone = get_network(network_config.backbone)
net = ReactNet(backbone)

elif network_config.name == 'ras_net':
backbone = get_network(network_config.backbone)
net = RASNet(backbone)

elif network_config.name == 'csi_net':
# don't wrap ddp here cuz we need to modify
# backbone
Expand Down
2 changes: 1 addition & 1 deletion openood/postprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .opengan_postprocessor import OpenGanPostprocessor
from .openmax_postprocessor import OpenMax
from .patchcore_postprocessor import PatchcorePostprocessor
from .ras_postprocessor import RASPostprocessor
from .rd4ad_postprocessor import Rd4adPostprocessor
from .react_postprocessor import ReactPostprocessor
from .rmds_postprocessor import RMDSPostprocessor
Expand All @@ -50,4 +51,3 @@
from .grood import GrOODPostprocessor
from .vra_postprocessor import VRAPostprocessor
from .cfood_postprocessor import CFOODPostprocessor

45 changes: 45 additions & 0 deletions openood/postprocessors/ras_postprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any

import torch
import torch.nn as nn
from tqdm import tqdm

from .base_postprocessor import BasePostprocessor


class RASPostprocessor(BasePostprocessor):
def __init__(self, config):
super(RASPostprocessor, self).__init__(config)
self.setup_flag = False

def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict):
if not self.setup_flag:
activation_log = []
net.eval()
with torch.no_grad():
loader = id_loader_dict['train']
for batch in tqdm(loader,
desc='Setup: ',
position=0,
leave=True):
data = batch['data'].cuda()
data = data.float()

_, feature = net(data, return_feature=True)
activation_log.append(feature.data.cpu())

activation_log = torch.cat(activation_log, dim=0)
sorted_vals, _ = torch.sort(activation_log, dim=1)
mean_curve = sorted_vals.mean(dim=0).cuda()

net.set_mean_curve(mean_curve)
self.mean_curve = mean_curve
self.setup_flag = True

@torch.no_grad()
def postprocess(self, net: nn.Module, data: Any):
output = net.forward_shift(data)
score = torch.softmax(output, dim=1)
_, pred = torch.max(score, dim=1)
conf = torch.logsumexp(output.data.cpu(), dim=1)
return pred, conf
2 changes: 2 additions & 0 deletions openood/postprocessors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .relation_postprocessor import RelationPostprocessor
from .grood import GrOODPostprocessor
from .vra_postprocessor import VRAPostprocessor
from .ras_postprocessor import RASPostprocessor


def get_postprocessor(config: Config):
Expand Down Expand Up @@ -94,6 +95,7 @@ def get_postprocessor(config: Config):
't2fnorm': T2FNormPostprocessor,
'grood': GrOODPostprocessor,
'vra': VRAPostprocessor,
'ras': RASPostprocessor,
}

return postprocessors[config.postprocessor.name](config)