JAX examples: using jit with class methods
JAX의 코드 작성방식을 다양한 예시를 통해 학습한다.
1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import jax
import jax.numpy as jnp
from jax import vmap, jit
from functools import partial
jax.config.update("jax_enable_x64", True)
class A:
    """Toy container holding an array ``a`` and an optional stored value ``b``.

    NOTE: ``f`` is jitted with ``self`` marked static (``static_argnums=(0,)``).
    ``A`` defines no ``__hash__``/``__eq__``, so the compiled function is cached
    per instance identity: reassigning ``self.a`` on an instance that has already
    called ``f`` is not picked up by the cached trace (the "problem" example later
    in this post shows exactly this).
    """

    def __init__(self, a: jnp.array)->None:
        self.a = a
        # Delegate initialisation of the remaining state to Init().
        self.Init()

    def Init(self)->None:
        """Reset the stored value ``b`` to ``None``."""
        self.b = None

    def set_b(self, x):
        """Store ``x`` as ``b`` on this instance."""
        self.b = x

    @partial(jit, static_argnums=(0,))
    def f(self, var: float)->float:
        """Return ``self.a`` scaled by ``var`` (``self`` is a static jit argument)."""
        return self.a * var
# Demo: two independent A instances. Each call to f uses that instance's own `a`,
# and set_b stores its argument on the instance it is called on; the numbered
# prints produce the output listed below.
objA = A(jnp.array([2.0]))
print("1)",objA.a, objA.b)  # b is None until set_b is called (Init sets it)
b = objA.f(10.)  # 2.0 * 10 -> [20.]
print("2)",b)
objA.set_b(b)
print("3)",objA.a, objA.b)
# A new instance is a different static `self` for f, so it gets its own trace;
# objA is left untouched.
new_objA = A(jnp.array([3.0]))
print("4)",objA.a, objA.b)
print("5)",new_objA.a, new_objA.b)
new_b = new_objA.f(20.)  # 3.0 * 20 -> [60.]
new_objA.set_b(new_b)
print("6)",objA.a, objA.b)
print("7)",new_objA.a, new_objA.b)
1
2
3
4
5
6
7
1) [2.] None
2) [20.]
3) [2.] [20.]
4) [2.] [20.]
5) [3.] None
6) [2.] [20.]
7) [3.] [60.]
문제
1
2
3
4
5
6
7
# Problem: A.f is jitted with `self` static (static_argnums=(0,)), and A has no
# custom __hash__/__eq__, so the compiled function is cached per instance and the
# value of obj.a seen during the first trace stays baked in. Reassigning obj.a
# afterwards does not change the result of later calls on the same object.
obj = A(1)
print(obj.f(2))
# 2
obj.a = 2
print(obj.f(2)) # should print 4, but prints 2
# 2
답변
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class A():
    """Toy container whose ``a`` is passed to ``jit`` as a traced pytree leaf.

    ``A`` is registered as a pytree right after the class body, so ``f`` can be
    jitted with ``self`` as an ordinary (non-static) argument. Reassigning
    ``obj.a`` is therefore seen by the next call instead of being baked into a
    cached trace.
    """

    def __init__(self, a: jnp.ndarray) -> None:
        self.a = a
        self.Init()

    def Init(self) -> None:
        """Reset the stored value ``b`` to ``None``."""
        self.b = None

    def set_b(self, x):
        """Store ``x`` as ``b`` on this instance."""
        self.b = x

    @jit
    def f(self, var: float) -> float:
        """Return ``self.a * var``; ``self`` is traced through its pytree leaves."""
        b = self.a * var
        return b

    def _tree_flatten(self):
        """Return ``(children, aux_data)``: ``a`` is the only dynamic leaf."""
        # NOTE: ``b`` is not part of the pytree. Unflattening goes through
        # __init__/Init, so a rebuilt A (e.g. the ``self`` seen inside ``f``)
        # has ``b = None``. To keep it, store self.b among the children (if it
        # is not hashable) or in aux_data (if it is hashable).
        return (self.a,), ()

    @classmethod
    def _tree_unflatten(cls, aux, children):
        """Rebuild an ``A`` from the children produced by ``_tree_flatten``."""
        return cls(*children)


# `tree_util` is not imported as a bare name in this file, so reach it through
# the `jax` package (imported at the top) to avoid a NameError here.
jax.tree_util.register_pytree_node(A, A._tree_flatten, A._tree_unflatten)
# With A registered as a pytree, `self` is an ordinary traced argument of the
# jitted f, so the reassigned obj.a is used by the next call.
obj = A(1)
print(obj.f(2))
# 2
obj.a = 2
print(obj.f(2))
# 4
2
질문
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import jax.numpy as np
from jax import jit
from functools import partial
class World:
    """Question code: tries to advance simulation state in place from a jitted method.

    NOTE(review): this does not work as intended. ``step`` is jitted with
    ``self`` static (``static_argnums=(0,)``), so it is traced once per World
    instance and the trace is cached. The ``self.v`` / ``self.p`` assignments
    are Python side effects that run only during that trace: they leave the
    attributes holding traced values, and they are not replayed on later cached
    calls. ``step`` also returns nothing, so no computed state leaves the compiled
    function. The answers below return a new state object instead.
    """
    def __init__(self, p, v):
        # p: position, v: velocity (presumably 2-vectors, given the driver below)
        self.p = p
        self.v = v
    @partial(jit, static_argnums=(0,))
    def step(self, dt):
        a = - 9.8  # constant acceleration added to every component of v
        # Rebinds the attributes (jnp arrays are immutable); happens only while tracing.
        self.v += a * dt
        self.p += self.v *dt
# Question driver: call step 1000 times with dt = 0.01, then print the position.
# NOTE(review): step mutates `self` inside a jitted function that treats `self`
# as static and returns nothing, so world.p here is not the simulated position
# the author expects.
world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world.step(0.01)
print(world.p)
답변
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import jax.numpy as np
from jax import jit
from collections import namedtuple
World = namedtuple("World", ["p", "v"])


@jit
def step(world, dt):
    """Return ``world`` advanced by one time step of length ``dt``.

    A constant acceleration of -9.8 is added to every velocity component. The
    position is then moved with the *updated* velocity (semi-implicit Euler). The
    input state is not modified; a new ``World`` is returned, which keeps the
    function pure and safe to ``jit``.
    """
    accel = -9.8
    v_next = world.v + accel * dt
    return world._replace(p=world.p + v_next * dt, v=v_next)
# Driver: thread the state through the pure `step` by rebinding `world` to its result.
# NOTE(review): the initial arrays are integer-typed while step returns floats,
# so jit likely traces step twice (once per input dtype); starting from float
# arrays such as np.array([0., 0.]) would avoid the extra trace.
world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = step(world, 0.01)
print(world.p)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from jax.tree_util import register_pytree_node
from functools import partial
class World:
    """Point-mass state (position ``p``, velocity ``v``) usable inside ``jit``.

    ``World`` is registered as a pytree right after its definition, so ``step``
    can be jitted with ``self`` as a regular traced argument. ``step`` returns
    a new ``World`` rather than mutating the receiver.
    """

    def __init__(self, p, v):
        self.p = p
        self.v = v

    @jit
    def step(self, dt):
        """Return a new World advanced by ``dt``.

        Adds acceleration -9.8 to every velocity component, then moves the
        position with the updated velocity.
        """
        accel = -9.8
        v_next = self.v + accel * dt
        return World(self.p + v_next * dt, v_next)


def _world_flatten(world):
    """Split a World into its dynamic children ``(p, v)`` and no aux data."""
    return (world.p, world.v), None


def _world_unflatten(_aux, children):
    """Rebuild a World from the ``(p, v)`` children made by ``_world_flatten``."""
    p, v = children
    return World(p, v)


# By registering 'World' as a pytree, it turns into a transparent container and
# can be used as an argument to any JAX-transformed functions.
register_pytree_node(World, _world_flatten, _world_unflatten)
# Driver: step returns a new World, so rebind `world` on every iteration.
# NOTE(review): the initial arrays are integer-typed while step returns floats,
# so jit likely traces step twice (once per input dtype); float initial arrays
# would avoid the extra trace.
world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = world.step(0.01)
print(world.p)