from collections import defaultdict
from .log_utils import default_logger as logger
[docs]class Registrable:
"""Any class that inherits from ``Registrable`` gains access to a named registry for its subclasses. To register them, just decorate them with the classmethod ``@BaseClass.register(name)``.
After which you can call ``BaseClass.list_available()`` to get the keys for the registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass.
Note that the registry stores the subclasses themselves; not class instances. In most cases you would then call ``from_params(params)`` on the returned subclass.
"""
_registry = defaultdict(dict)
_default_impl = None
[docs] @classmethod
def register(cls, name, constructor=None, overwrite=False):
"""Register a class under a particular name.
Args:
name (str): The name to register the class under.
constructor (str): optional (default=None)
The name of the method to use on the class to construct the object. If this is given,
we will use this method (which must be a ``classmethod``) instead of the default
constructor.
overwrite (bool) : optional (default=False)
If True, overwrites any existing models registered under ``name``. Else,
throws an error if a model is already registered under ``name``.
# Examples
To use this class, you would typically have a base class that inherits from ``Registrable``:
```python
class Transform(Registrable):
...
```
Then, if you want to register a subclass, you decorate it like this:
```python
@Transform.register("shift-transform")
class ShiftTransform(Transform):
def __init__(self, param1: int, param2: str):
...
```
Registering a class like this will let you instantiate a class from a config file, where you
give ``"type": "shift-transform"``, and keys corresponding to the parameters of the ``__init__``
method (note that for this to work, those parameters must have type annotations).
If you want to have the instantiation from a config file call a method other than the
constructor, either because you have several different construction paths that could be
taken for the same object (as we do in ``Transform``) or because you have logic you want to
happen before you get to the constructor, you can register a specific ``@classmethod`` as the constructor to use.
"""
registry = Registrable._registry[cls]
def add_subclass_to_registry(subclass):
# Add to registry, raise an error if key has already been used.
if name in registry:
if overwrite:
message = (
f"{name} has already been registered as {registry[name][0].__name__}, but "
f"overwrite=True, so overwriting with {cls.__name__}"
)
logger.info(message)
else:
message = (
f"Cannot register {name} as {cls.__name__}; "
f"name already in use for {registry[name][0].__name__}"
)
raise RuntimeError(message)
registry[name] = (subclass, constructor)
return subclass
return add_subclass_to_registry
[docs] @classmethod
def by_name(cls, name):
"""
Returns a callable function that constructs an argument of the registered class. Because
you can register particular functions as constructors for specific names, this isn't
necessarily the ``__init__`` method of some class.
"""
logger.debug(f"instantiating registered subclass {name} of {cls}")
subclass, constructor = cls.resolve_class_name(name)
if not constructor:
return subclass
else:
return getattr(subclass, constructor)
[docs] @classmethod
def resolve_class_name(cls, name):
"""
Returns the subclass that corresponds to the given ``name``, along with the name of the
method that was registered as a constructor for that ``name``, if any.
This method also allows ``name`` to be a fully-specified module name, instead of a name that
was already added to the ``Registry``. In that case, you cannot use a separate function as
a constructor (as you need to call ``cls.register()`` in order to tell us what separate
function to use).
"""
if name in Registrable._registry[cls]:
subclass, constructor = Registrable._registry[cls].get(name)
return subclass, constructor
else:
for base_cls, v in Registrable._registry.items():
if name in v:
subclass, constructor = Registrable._registry[base_cls].get(name)
return subclass, constructor
if "." in name:
# This might be a fully qualified class name, so we'll try importing its "module"
# and finding it there.
parts = name.split(".")
submodule = ".".join(parts[:-1])
class_name = parts[-1]
import importlib
try:
module = importlib.import_module(submodule)
except ModuleNotFoundError:
raise RuntimeError(
f"tried to interpret {name} as a path to a class "
f"but unable to import module {submodule}"
)
try:
subclass = getattr(module, class_name)
constructor = None
return subclass, constructor
except AttributeError:
raise RuntimeError(
f"tried to interpret {name} as a path to a class "
f"but unable to find class {class_name} in {submodule}"
)
else:
# is not a qualified class name
raise RuntimeError(
f"{name} is not a registered name for {cls.__name__}. "
"You probably need to use the --include-package flag "
"to load your custom code. Alternatively, you can specify your choices "
"""using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
"in which case they will be automatically imported correctly."
)
[docs] @classmethod
def list_available(cls):
"""List default first if it exists"""
keys = list(Registrable._registry[cls].keys())
default = cls._default_impl
if default is None:
return keys
elif default not in keys:
raise RuntimeError(f"Default implementation {default} is not registered")
else:
return [default] + [k for k in keys if k != default]
[docs] @classmethod
def registry_dict(cls):
return Registrable._registry[cls]