Source code for enerzyme.data.datatype

# Property flags. Each flag occupies one bit of a DATA_TYPES value, so
# several of them can be combined with ``|`` and tested with ``&``.
IS_INT = 1 << 0
IS_ROUNDED = 1 << 1
IS_ATOMIC = 1 << 2
REQUIRES_GRAD = 1 << 3
IS_IDX = 1 << 4
IS_TARGET = 1 << 5
IS_GRAD = 1 << 6

# The tensor rank is stored above the flag bits as a multiple of
# TENSOR_RANK_BASE.
# NOTE: the base is 2 << 8 == 512 (bit 9), i.e. twice 1 << TENSOR_RANK_BIT,
# so ``value >> TENSOR_RANK_BIT`` evaluates to 2 * rank, not rank.
TENSOR_RANK_BIT = 8
TENSOR_RANK_BASE = 2 << TENSOR_RANK_BIT

# Built-in data keys mapped to their combined flag / rank bit field.
DATA_TYPES = {
    # Integer entries.
    "N": IS_IDX | IS_INT,
    "Za": IS_ATOMIC | IS_INT,
    # Per-atom rank-1 entry that requires gradients.
    "Ra": REQUIRES_GRAD | (1 * TENSOR_RANK_BASE) | IS_ATOMIC,
    # Rank-0 targets.
    "Q": IS_TARGET | IS_ROUNDED,
    "Qa": IS_TARGET | IS_ATOMIC,
    "S": IS_TARGET | IS_ROUNDED,
    "E": IS_TARGET,
    # Rank-1 targets.
    "Fa": IS_GRAD | IS_TARGET | (1 * TENSOR_RANK_BASE) | IS_ATOMIC,
    "M2": IS_TARGET | (1 * TENSOR_RANK_BASE),
    "M2a": IS_TARGET | (1 * TENSOR_RANK_BASE) | IS_ATOMIC,
    # Integer index entries.
    "idx_i": IS_IDX | IS_INT,
    "idx_j": IS_IDX | IS_INT,
    "N_pair": IS_IDX | IS_INT,
}

def is_int(k):
    """Return ``True`` if the ``IS_INT`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & IS_INT) != 0
def is_rounded(k):
    """Return ``True`` if the ``IS_ROUNDED`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & IS_ROUNDED) != 0
def is_atomic(k):
    """Return ``True`` if the ``IS_ATOMIC`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & IS_ATOMIC) != 0
def requires_grad(k):
    """Return ``True`` if the ``REQUIRES_GRAD`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & REQUIRES_GRAD) != 0
def is_idx(k):
    """Return ``True`` if the ``IS_IDX`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & IS_IDX) != 0
def is_target(k):
    """Return ``True`` if the ``IS_TARGET`` bit is set for data key ``k``.

    Keys that are not present in ``DATA_TYPES`` are treated as having no
    flags set, so the result is ``False`` for them.
    """
    flags = DATA_TYPES.get(k, 0)
    return (flags & IS_TARGET) != 0
def is_target_uq(k):
    """Return ``True`` if ``k`` is a target key suffixed with ``_var`` or ``_std``.

    The four-character suffix is stripped and the remaining name is passed
    to ``is_target``. Keys without either suffix always give ``False``.
    (The suffixes presumably denote variance / standard deviation of a
    target -- confirm against the callers.)
    """
    # Both suffixes are four characters long, so one slice handles either.
    return k.endswith(("_var", "_std")) and is_target(k[:-4])
def is_grad(k):
    """Return ``True`` if ``k`` is flagged ``IS_GRAD`` or ends with ``_grad``.

    The registry lookup is tried first; the name suffix is only inspected
    when the ``IS_GRAD`` bit is not set (or ``k`` is not registered).
    """
    flagged = (DATA_TYPES.get(k, 0) & IS_GRAD) != 0
    return flagged or k.endswith("_grad")
def get_tensor_rank(k):
    """Return the tensor rank registered for data key ``k`` as an ``int``.

    The rank is stored in ``DATA_TYPES`` as a multiple of
    ``TENSOR_RANK_BASE`` on top of the single-bit property flags, so integer
    division by the base recovers it. Keys that are not registered, and
    keys registered without a rank component, give ``0``.

    The result is a plain integer rather than a truth value, so ranks above
    one stay distinguishable. For ranks 0 and 1 it compares equal to
    ``False`` / ``True`` and has the same truthiness.
    """
    # Divide by the base instead of shifting by TENSOR_RANK_BIT: the base is
    # defined as 2 << TENSOR_RANK_BIT, so the shift would yield 2 * rank.
    return DATA_TYPES.get(k, 0) // TENSOR_RANK_BASE
# Keyword names accepted by ``register_data_type`` mapped to the flag bit
# each of them controls.
TYPE_ATTRS = {
    "is_atomic": IS_ATOMIC,
}
def register_data_type(k, **type_info):
    """Register (or reset) data key ``k`` in ``DATA_TYPES``.

    The entry for ``k`` is first set to ``0``, discarding any previous
    registration. Each keyword argument must then name an entry of
    ``TYPE_ATTRS``: the corresponding bit is set when the value is exactly
    ``True`` and cleared for any other value (including truthy ones such
    as ``1``).

    Raises:
        KeyError: if a keyword is not a key of ``TYPE_ATTRS``. Attributes
            processed before the unknown one remain applied.
    """
    DATA_TYPES[k] = 0
    for attr_name, enabled in type_info.items():
        mask = TYPE_ATTRS[attr_name]
        current = DATA_TYPES[k]
        # Identity check on purpose: only the literal ``True`` sets the bit.
        DATA_TYPES[k] = (current | mask) if enabled is True else (current & ~mask)
# Names exported by ``from ... import *``.
__all__ = [
    "is_int",
    "is_rounded",
    "is_atomic",
    "requires_grad",
    "is_idx",
    "get_tensor_rank",
    "is_target",
    "is_target_uq",
    "register_data_type",
    "is_grad",
]