# Source code for gensbi.recipes.utils
from functools import partial
from typing import Tuple, Union

import jax
import numpy as np
from einops import rearrange, repeat
from jax import Array
from jax import numpy as jnp
[docs]
def init_ids_joint(dim_obs: int, dim_cond: int):
    """Build consecutive integer ids for a joint (observation + condition) sequence.

    Observation ids come first, followed by the conditional ids, so the two
    ranges tile the joint range ``[0, dim_obs + dim_cond)`` without overlap.

    Args:
        dim_obs: Number of observation entries.
        dim_cond: Number of conditioning entries.

    Returns:
        A tuple ``(node_ids, obs_ids, cond_ids)``. Each is a ``jnp.arange``
        reshaped to ``(1, length, 1)``:

        * ``node_ids``: ``0 .. dim_obs + dim_cond - 1``
        * ``obs_ids``: ``0 .. dim_obs - 1``
        * ``cond_ids``: ``dim_obs .. dim_obs + dim_cond - 1``
    """

    def _id_column(start, stop):
        # Leading singleton axis and trailing singleton axis around the id range.
        return jnp.arange(start, stop).reshape((1, -1, 1))

    total = dim_obs + dim_cond
    return (
        _id_column(0, total),
        _id_column(0, dim_obs),
        _id_column(dim_obs, total),
    )
[docs]
def init_ids_1d(dim: int, semantic_id: Union[int, None] = None):
    """Build position ids for a 1D sequence, optionally tagged with a semantic id.

    Args:
        dim: Sequence length.
        semantic_id: If ``None``, only position ids are produced. Otherwise
            every position additionally carries this value in a second channel.

    Returns:
        A ``jnp.int32`` array of shape ``(1, dim, 1)`` when ``semantic_id`` is
        ``None``, or ``(1, dim, 2)`` otherwise. Channel 0 holds the position
        ``0 .. dim - 1``; channel 1, when present, holds ``semantic_id``.
    """
    n_channels = 1 if semantic_id is None else 2
    table = np.zeros((1, dim, n_channels), dtype=np.int32)

    # Channel 0: position index along the sequence.
    table[0, :, 0] = np.arange(dim)

    # Channel 1 (optional): the same semantic tag broadcast to every position.
    if semantic_id is not None:
        table[..., 1] = semantic_id

    return jnp.array(table, dtype=jnp.int32)
[docs]
def init_ids_2d(
    dim: Tuple[int, int], semantic_id: int = 0, patch_size: int = 2
) -> Array:
    """Build (semantic, row, col) ids for the patch grid of a 2D image.

    The image of spatial size ``dim = (H, W)`` is viewed as a grid of
    ``(H // patch_size) x (W // patch_size)`` non-overlapping patches. Each
    patch gets the id triple ``(semantic_id, row, col)``. The grid is
    flattened row-major, i.e. the same ``(h w)`` ordering that
    ``patchify_2d`` uses for its tokens (with the default ``patch_size=2``).

    Args:
        dim: Spatial size ``(H, W)``. Sizes not divisible by ``patch_size``
            are floored, so trailing rows/columns get no id.
        semantic_id: Value stored in channel 0 of every id.
        patch_size: Side length of a square patch. The default of 2 reproduces
            the previous hard-coded behaviour.

    Returns:
        A ``jnp.int32`` array of shape ``(1, (H // p) * (W // p), 3)``.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")

    n_h = dim[0] // patch_size
    n_w = dim[1] // patch_size

    img_ids = np.zeros((n_h, n_w, 3), dtype=np.int32)
    img_ids[..., 0] = semantic_id
    img_ids[..., 1] = np.arange(n_h)[:, None]  # patch row index
    img_ids[..., 2] = np.arange(n_w)[None, :]  # patch column index

    # Row-major flatten of the (h, w) grid plus a leading batch axis of 1.
    # This is equivalent to einops "h w c -> b (h w) c" with b=1.
    img_ids = img_ids.reshape(1, n_h * n_w, 3)
    return jnp.array(img_ids, dtype=jnp.int32)
@partial(jax.jit, static_argnames=("patch_size",))
def patchify_2d(x: Array, patch_size: int = 2) -> Array:
    """Split channel-last images into flattened, non-overlapping square patches.

    Implements the einops pattern
    ``"b (h ph) (w pw) c -> b (h w) (c ph pw)"`` with ``ph = pw = patch_size``.

    Args:
        x: Array laid out as ``(B, H, W, C)``. ``H`` and ``W`` must be
            divisible by ``patch_size``; otherwise einops raises.
        patch_size: Side length of each square patch. The default of 2
            reproduces the previous hard-coded behaviour. It is a static jit
            argument, so each distinct value triggers its own compilation.

    Returns:
        Array of shape ``(B, (H // p) * (W // p), C * p * p)``. Tokens are
        ordered row-major over the patch grid, and each token's features are
        ordered ``(c, ph, pw)``.
    """
    return rearrange(
        x,
        "b (h ph) (w pw) c -> b (h w) (c ph pw)",
        ph=patch_size,
        pw=patch_size,
    )