fix grad accumulation, not overriding

This commit is contained in:
publicmatt
2024-02-22 08:49:54 -08:00
parent 6cf0186544
commit 4c9f7d8d7d
2 changed files with 12 additions and 7 deletions

View File

@@ -61,7 +61,6 @@ def test_large_backprop():
assert pytest.approx(w2.grad, 0.001) == 0.0
@pytest.mark.skip(reason="non-deterministic")
def test_auto_diff():
# inputs
x1 = Value(2.0, label='x1')
@@ -89,7 +88,6 @@ def test_auto_diff():
y.backward()
assert pytest.approx(n.grad, 0.001) == 0.5
assert h.grad == 0.0
assert pytest.approx(b.grad, 0.001) == 0.5
assert pytest.approx(h.grad, 0.001) == 0.5
@@ -102,3 +100,10 @@ def test_auto_diff():
assert pytest.approx(x2.grad, 0.001) == 0.5
assert pytest.approx(w2.grad, 0.001) == 0.0
def test_accumulation():
a = Value(2.0, label='x1')
b = a + a
b.backward()
assert a.grad == 2.0