Source code for internutopia.core.task.metric
from abc import ABC, abstractmethod
from internutopia.core.config.metric import MetricCfg
from internutopia.core.task_config_manager.base import TaskCfg
class BaseMetric(ABC):
    """Abstract base class for metrics, with a name-to-class registry.

    Concrete metrics implement :meth:`update` and :meth:`calc`. A metric class
    is added to the registry by decorating it with
    ``@BaseMetric.register(name)``.
    """

    # Registry mapping a registered name to its metric class. `register`
    # mutates this dict in place, so it is shared by BaseMetric and every
    # subclass that does not rebind `metrics`.
    metrics = {}

    def __init__(self, config: MetricCfg, task_config: TaskCfg):
        """Store the configuration objects and reset the runtime context.

        Args:
            config (MetricCfg): metric configuration; ``config.name`` and
                ``config.metric_config`` are copied onto the instance.
            task_config (TaskCfg): configuration of the task this metric
                belongs to.
        """
        # Runtime context: unknown at construction time, filled in later by
        # `set_up_runtime`.
        self.env_offset = None
        self.env_id = None
        self.task_name = None

        self.config = config
        self.name = config.name
        self.task_config = task_config
        self.metric_config = config.metric_config

    @abstractmethod
    def update(self, *args):
        """
        This function is called at each world step.
        """
        raise NotImplementedError(f'`update` function of {self.name} is not implemented')

    @abstractmethod
    def calc(self):
        """
        This function is called to calculate the metrics when the episode is terminated.
        """
        raise NotImplementedError(f'`calc` function of {self.name} is not implemented')

    @classmethod
    def register(cls, name: str):
        """
        Register a metric class with the given name (decorator).

        An existing entry under the same name is replaced.

        Args:
            name(str): name of the metric

        Returns:
            A decorator that stores the decorated class in ``cls.metrics`` and
            returns that class unchanged.
        """

        def wrapper(metric_class):
            """
            Register the metric class.
            """
            cls.metrics[name] = metric_class
            return metric_class

        return wrapper

    def set_up_runtime(self, task_name, env_id, env_offset) -> None:
        """Record the runtime context on this metric instance.

        Args:
            task_name: stored as ``self.task_name``.
            env_id: stored as ``self.env_id``.
            env_offset: stored as ``self.env_offset``.
        """
        self.task_name = task_name
        self.env_id = env_id
        self.env_offset = env_offset
def create_metric(config: MetricCfg, task_config: TaskCfg):
    """Instantiate the metric class registered under ``config.type``.

    Args:
        config (MetricCfg): metric configuration; ``config.type`` is the key
            looked up in ``BaseMetric.metrics``.
        task_config (TaskCfg): task configuration passed through to the
            metric constructor.

    Returns:
        The result of calling the registered class with
        ``(config, task_config)``.

    Raises:
        KeyError: if no class is registered under ``config.type``.
    """
    try:
        metric_cls = BaseMetric.metrics[config.type]
    except KeyError:
        # Replace the bare key error with an actionable message; `from None`
        # hides the uninformative original lookup failure from the traceback.
        raise KeyError(
            f"""The metric {config.type} is not registered, please register it using `@BaseMetric.register`"""
        ) from None
    return metric_cls(config, task_config)