# in projects/home/recap/model/feature_transform.py [0:0]
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
    """Build a ``torch.nn.LayerNorm`` from a ``LayerNormConfig``.

    Args:
      normalized_shape: Shape of the trailing dimension(s) to normalize over;
        passed straight through to ``torch.nn.LayerNorm``.
      config: Layer-norm settings. Only ``axis == -1`` is supported, and
        ``center`` must equal ``scale`` because torch exposes a single
        ``elementwise_affine`` flag that enables both together.

    Raises:
      NotImplementedError: If ``config.axis`` is anything other than ``-1``.
      ValueError: If ``config.center`` and ``config.scale`` disagree.
    """
    super().__init__()
    # torch.nn.LayerNorm always normalizes over the trailing dimensions,
    # so any axis other than -1 cannot be honored.
    if config.axis != -1:
        raise NotImplementedError(
            f"LayerNorm only supports axis=-1 in torch, received axis={config.axis}"
        )
    # torch's elementwise_affine enables the learned shift (center) and the
    # learned gain (scale) jointly; a mismatched pair is inexpressible.
    if config.center != config.scale:
        raise ValueError(
            f"Center and scale must match in torch, received {config.center}, {config.scale}"
        )
    self.layer = torch.nn.LayerNorm(
        normalized_shape, eps=config.epsilon, elementwise_affine=config.center
    )