2024-03-18 19:45:15 -07:00
|
|
|
import pytest
|
2024-02-22 07:15:45 -08:00
|
|
|
from micrograd.engine import Value
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_init():
    """A freshly constructed Value exposes the wrapped number as ``data``."""
    assert Value(1).data == 1
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_repr():
    """``repr`` of a Value renders as ``Value(data=<number>)``."""
    rendered = repr(Value(2.0))
    assert rendered == "Value(data=2.0)"
|
|
|
|
|
|
|
|
|
2024-03-18 19:45:15 -07:00
|
|
|
def test_value_add_opt():
    """Adding two Values yields a Value holding the sum."""
    lhs, rhs = Value(2.0), Value(4.0)
    total = lhs + rhs

    # Check both the numeric payload and its printed form.
    assert total.data == 6.0
    assert repr(total) == "Value(data=6.0)"
|
|
|
|
|
|
|
|
|
2024-03-18 19:45:15 -07:00
|
|
|
def test_value_sub_opt():
    """Subtracting two Values yields a Value holding the difference."""
    lhs, rhs = Value(2.0), Value(4.0)
    difference = lhs - rhs

    # Check both the numeric payload and its printed form.
    assert difference.data == -2.0
    assert repr(difference) == "Value(data=-2.0)"
|
|
|
|
|
|
|
|
|
2024-03-18 19:45:15 -07:00
|
|
|
def test_value_mul_opt():
    """Multiplying Values works for positive and negative factors."""
    base = Value(2.0)
    positive = Value(4.0)
    negative = Value(-1.0)

    products = [(base * positive).data, (base * negative).data]
    assert products[0] == 8.0
    assert products[1] == -2.0
|
|
|
|
|
|
|
|
|
2024-03-18 19:45:15 -07:00
|
|
|
def test_value_rmul_opt():
    """A plain number on the left-hand side of ``*`` multiplies a Value."""
    doubled = 2 * Value(2.0)
    assert doubled.data == 4.0
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_pow_opt():
    """Raising a Value to an integer power yields the expected number."""
    squared = Value(2.0) ** 2
    assert squared.data == 4.0
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_exp_opt():
    """``Value(1.0).exp()`` holds Euler's number.

    The earlier assertion, ``pytest.approx(b.data, 0.1) == 2.7``, passed the
    computed value as the *expected* argument with a 10% relative tolerance,
    so a badly wrong ``exp`` could still pass. The computed value now sits on
    the left and is compared with ``math.e`` under ``pytest.approx``'s
    default tolerance.
    """
    import math  # local import keeps this block self-contained

    a = Value(1.0)
    b = a.exp()
    assert b.data == pytest.approx(math.e)
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_int_opt():
    """Subtracting a plain int from a Value yields the expected number."""
    reduced = Value(2.0) - 1
    assert reduced.data == 1.0
|
|
|
|
|
|
|
|
|
|
|
|
def test_value_div_opt():
    """Dividing a Value by a plain int yields the expected number."""
    halved = Value(2.0) / 2
    assert halved.data == 1.0
|
|
|
|
|
|
|
|
|
2024-02-22 07:15:45 -08:00
|
|
|
def test_value_mul_add():
    """A product of Values can be fed straight into an addition."""
    two, four, minus_one = Value(2.0), Value(4.0), Value(-1.0)

    product = two * minus_one
    combined = product + four
    assert combined.data == 2.0
|
|
|
|
|
|
|
|
|
|
|
|
def test_children():
    """The result of ``*`` records exactly its two operands in ``_prev``."""
    left, right = Value(2.0), Value(4.0)
    product = left * right

    assert len(product._prev) == 2
    for operand in (left, right):
        assert operand in product._prev
|
|
|
|
|
|
|
|
|
|
|
|
def test_operations():
    """Results of ``*`` and ``+`` carry the operator symbol in ``_op``."""
    left, right = Value(2.0), Value(4.0)

    product = left * right
    total = left + right

    assert product._op == '*'
    assert total._op == '+'
|
2024-02-22 08:41:04 -08:00
|
|
|
|
|
|
|
|
|
|
|
def test_tanh():
    """Check ``Value.tanh`` for a positive, a zero and a negative input.

    The original sign and range checks are kept. For the non-zero inputs the
    result is also compared with ``math.tanh``, because the range checks
    alone would accept almost any function that maps into (-1, 1).
    """
    import math  # local import keeps this block self-contained

    # Positive input: result lies strictly between 0 and 1.
    t = Value(2.0).tanh()
    assert t.data > 0
    assert t.data < 1
    assert t.data == pytest.approx(math.tanh(2.0))

    # Zero input: result is exactly zero.
    t = Value(0.0).tanh()
    assert t.data == 0

    # Negative input: result lies strictly between -1 and 0.
    t = Value(-2.0).tanh()
    assert t.data < 0
    assert t.data > -1
    assert t.data == pytest.approx(math.tanh(-2.0))
|