Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def extract_k_best_modalities_per_task(self):
for modality in self.modalities:
k_best_results, cached_data = (
self.optimization_results.get_k_best_results(
modality, self.k, task, self.scoring_metric
modality, task, self.scoring_metric
)
)
representations[task.model.name][modality.modality_id] = k_best_results
Expand Down
14 changes: 6 additions & 8 deletions src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _extract_k_best_representations(
for modality in self.modalities:
k_best_results, cached_data = (
unimodal_optimization_results.get_k_best_results(
modality, self.k, task, self.metric_name
modality, task, self.metric_name
)
)

Expand Down Expand Up @@ -359,26 +359,24 @@ def _evaluate_dag(self, dag: RepresentationDag, task: Task) -> "OptimizationResu
)
),
task,
enable_cache=False,
)

torch.cuda.empty_cache()

if fused_representation is None:
return None

final_representation = fused_representation[
list(fused_representation.keys())[-1]
]
if task.expected_dim == 1 and get_shape(final_representation.metadata) > 1:
if task.expected_dim == 1 and get_shape(fused_representation.metadata) > 1:
agg_operator = AggregatedRepresentation(Aggregation())
final_representation = agg_operator.transform(final_representation)
fused_representation = agg_operator.transform(fused_representation)

eval_start = time.time()
scores = task.run(final_representation.data)
scores = task.run(fused_representation.data)
eval_time = time.time() - eval_start

total_time = time.time() - start_time

del fused_representation
return OptimizationResult(
dag=dag,
train_score=scores[0].average_scores,
Expand Down
22 changes: 13 additions & 9 deletions src/main/python/systemds/scuro/drsearch/representation_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# -------------------------------------------------------------
import copy
from dataclasses import dataclass, field
from typing import List, Dict, Any
from typing import List, Dict, Union, Any, Hashable, Optional
from systemds.scuro.modality.modality import Modality
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.representation import (
Expand All @@ -34,9 +34,7 @@
DimensionalityReduction,
)
from systemds.scuro.utils.identifier import get_op_id, get_node_id

from collections import OrderedDict
from typing import Any, Hashable, Optional


class LRUCache:
Expand Down Expand Up @@ -161,7 +159,9 @@ def execute(
modalities: List[Modality],
task=None,
external_cache: Optional[LRUCache] = None,
) -> Dict[str, TransformedModality]:
enable_cache=True,
rep_cache: Dict[Any, TransformedModality] = None,
) -> Union[Dict[str, TransformedModality], TransformedModality]:
cache: Dict[str, TransformedModality] = {}
node_signatures: Dict[str, Hashable] = {}

Expand All @@ -175,7 +175,8 @@ def execute_node(node_id: str, task) -> TransformedModality:
modality = get_modality_by_id_and_instance_id(
modalities, node.modality_id, node.representation_index
)
cache[node_id] = modality
if enable_cache:
cache[node_id] = modality
node_signatures[node_id] = self._compute_leaf_signature(node)
return modality

Expand Down Expand Up @@ -203,7 +204,9 @@ def execute_node(node_id: str, task) -> TransformedModality:
elif isinstance(node_operation, AggregatedRepresentation):
result = node_operation.transform(input_mods[0])
elif isinstance(node_operation, UnimodalRepresentation):
if (
if rep_cache is not None:
result = rep_cache[node_operation.name]
elif (
isinstance(input_mods[0], TransformedModality)
and input_mods[0].transformation[0].__class__
== node.operation
Expand All @@ -228,13 +231,14 @@ def execute_node(node_id: str, task) -> TransformedModality:
if external_cache and is_unimodal:
external_cache.put(node_signature, result)

cache[node_id] = result
if enable_cache:
cache[node_id] = result
node_signatures[node_id] = node_signature
return result

execute_node(self.root_node_id, task)
result = execute_node(self.root_node_id, task)

return cache
return cache if enable_cache else result


def get_modality_by_id_and_instance_id(
Expand Down
65 changes: 50 additions & 15 deletions src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,32 @@

class UnimodalOptimizer:
def __init__(
self, modalities, tasks, debug=True, save_all_results=False, result_path=None
self,
modalities,
tasks,
debug=True,
save_all_results=False,
result_path=None,
k=2,
metric_name="accuracy",
):
self.modalities = modalities
self.tasks = tasks
self.modality_ids = [modality.modality_id for modality in modalities]
self.save_all_results = save_all_results
self.result_path = result_path

self.k = k
self.metric_name = metric_name
self.builders = {
modality.modality_id: RepresentationDAGBuilder() for modality in modalities
}

self.debug = debug

self.operator_registry = Registry()
self.operator_performance = UnimodalResults(modalities, tasks, debug, True)
self.operator_performance = UnimodalResults(
modalities, tasks, debug, True, k, metric_name
)

self._tasks_require_same_dims = True
self.expected_dimensions = tasks[0].expected_dim
Expand Down Expand Up @@ -185,12 +195,20 @@ def _process_modality(self, modality, parallel):
modality.modality_type
)
dags = []
operators = []
for operator in modality_specific_operators:
dags.extend(self._build_modality_dag(modality, operator()))
operators.append(operator())

external_cache = LRUCache(max_size=32)
rep_cache = None
if hasattr(modality, "data_loader") and modality.data_loader.chunk_size:
rep_cache = modality.apply_representations(operators)

for dag in dags:
representations = dag.execute([modality], external_cache=external_cache)
representations = dag.execute(
[modality], external_cache=external_cache, rep_cache=rep_cache
)
node_id = list(representations.keys())[-1]
node = dag.get_node_by_id(node_id)
if node.operation is None:
Expand Down Expand Up @@ -466,17 +484,26 @@ def temporal_context_operators(self, modality, builder, leaf_id, current_node_id


class UnimodalResults:
def __init__(self, modalities, tasks, debug=False, store_cache=True):
def __init__(
self,
modalities,
tasks,
debug=False,
store_cache=True,
k=-1,
metric_name="accuracy",
):
self.modality_ids = [modality.modality_id for modality in modalities]
self.task_names = [task.model.name for task in tasks]
self.results = {}
self.debug = debug
self.cache = {}
self.store_cache = store_cache

self.k = k
self.metric_name = metric_name
for modality in self.modality_ids:
self.results[modality] = {task_name: [] for task_name in self.task_names}
self.cache[modality] = {task_name: {} for task_name in self.task_names}
self.cache[modality] = {task_name: [] for task_name in self.task_names}

def add_result(self, scores, modality, task_name, task_time, combination, dag):
entry = ResultEntry(
Expand All @@ -491,12 +518,20 @@ def add_result(self, scores, modality, task_name, task_time, combination, dag):

self.results[modality.modality_id][task_name].append(entry)
if self.store_cache:
cache_key = (
id(dag),
scores[1],
modality.transform_time,
self.cache[modality.modality_id][task_name].append(modality)

results = self.results[modality.modality_id][task_name]
if self.k != -1 and len(results) > self.k:
ranked, sorted_indices = rank_by_tradeoff(
results, performance_metric_name=self.metric_name
)
self.cache[modality.modality_id][task_name][cache_key] = modality
keep = set(sorted_indices[: self.k])

self.cache[modality.modality_id][task_name] = [
m
for i, m in enumerate(self.cache[modality.modality_id][task_name])
if i in keep
]

if self.debug:
print(f"{modality.modality_id}_{task_name}: {entry}")
Expand All @@ -508,7 +543,7 @@ def print_results(self):
print(f"{modality}_{task_name}: {entry}")

def get_k_best_results(
self, modality, k, task, performance_metric_name, prune_cache=False
self, modality, task, performance_metric_name, prune_cache=False
):
"""
Get the k best results for the given modality
Expand All @@ -524,8 +559,8 @@ def get_k_best_results(
task_results, performance_metric_name=performance_metric_name
)

results = results[:k]
sorted_indices = sorted_indices[:k]
results = results[: self.k]
sorted_indices = sorted_indices[: self.k]

task_cache = self.cache.get(modality.modality_id, {}).get(task.model.name, None)
if not task_cache:
Expand Down
98 changes: 80 additions & 18 deletions src/main/python/systemds/scuro/modality/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,86 @@ def pad(self, value=0, max_len=None):
else:
raise "Needs padding to max_len"
except:
maxlen = (
max([len(seq) for seq in self.data]) if max_len is None else max_len
)

result = np.full((len(self.data), maxlen), value, dtype=self.data_type)

for i, seq in enumerate(self.data):
data = seq[:maxlen]
result[i, : len(data)] = data

if self.has_metadata():
attention_mask = np.zeros(result.shape[1], dtype=np.int8)
attention_mask[: len(seq[:maxlen])] = 1
md_key = list(self.metadata.keys())[i]
if "attention_mask" in self.metadata[md_key]:
self.metadata[md_key]["attention_mask"] = attention_mask
else:
self.metadata[md_key].update({"attention_mask": attention_mask})
first = self.data[0]
if isinstance(first, np.ndarray) and first.ndim == 3:
maxlen = (
max([seq.shape[0] for seq in self.data])
if max_len is None
else max_len
)
tail_shape = first.shape[1:]
result = np.full(
(len(self.data), maxlen, *tail_shape),
value,
dtype=self.data_type or first.dtype,
)
for i, seq in enumerate(self.data):
data = seq[:maxlen]
result[i, : len(data), ...] = data
if self.has_metadata():
attention_mask = np.zeros(maxlen, dtype=np.int8)
attention_mask[: len(data)] = 1
md_key = list(self.metadata.keys())[i]
if "attention_mask" in self.metadata[md_key]:
self.metadata[md_key]["attention_mask"] = attention_mask
else:
self.metadata[md_key].update(
{"attention_mask": attention_mask}
)
elif (
isinstance(first, list)
and len(first) > 0
and isinstance(first[0], np.ndarray)
and first[0].ndim == 2
):
maxlen = (
max([len(seq) for seq in self.data]) if max_len is None else max_len
)
row_dim, col_dim = first[0].shape
result = np.full(
(len(self.data), maxlen, row_dim, col_dim),
value,
dtype=self.data_type or first[0].dtype,
)
for i, seq in enumerate(self.data):
data = seq[:maxlen]
# stack list of 2D arrays into 3D then assign
if len(data) > 0:
result[i, : len(data), :, :] = np.stack(data, axis=0)
if self.has_metadata():
attention_mask = np.zeros(maxlen, dtype=np.int8)
attention_mask[: len(data)] = 1
md_key = list(self.metadata.keys())[i]
if "attention_mask" in self.metadata[md_key]:
self.metadata[md_key]["attention_mask"] = attention_mask
else:
self.metadata[md_key].update(
{"attention_mask": attention_mask}
)
else:
maxlen = (
max([len(seq) for seq in self.data]) if max_len is None else max_len
)
result = np.full((len(self.data), maxlen), value, dtype=self.data_type)
for i, seq in enumerate(self.data):
data = seq[:maxlen]
try:
result[i, : len(data)] = data
except:
print(f"Error padding data for modality {self.modality_id}")
print(f"Data shape: {data.shape}")
print(f"Result shape: {result.shape}")
raise Exception("Error padding data")
if self.has_metadata():
attention_mask = np.zeros(result.shape[1], dtype=np.int8)
attention_mask[: len(data)] = 1
md_key = list(self.metadata.keys())[i]
if "attention_mask" in self.metadata[md_key]:
self.metadata[md_key]["attention_mask"] = attention_mask
else:
self.metadata[md_key].update(
{"attention_mask": attention_mask}
)
# TODO: this might need to be a new modality (otherwise we loose the original data)
self.data = result

Expand Down
Loading
Loading