Fix gradient accumulation: add to grad instead of overwriting it

This commit is contained in:
publicmatt 2024-02-22 08:49:54 -08:00
parent 6cf0186544
commit 4c9f7d8d7d
2 changed files with 12 additions and 7 deletions

View File

@@ -34,8 +34,8 @@ class Value:
_children=(self, other), _op='+') _children=(self, other), _op='+')
def _backward(): def _backward():
self.grad = 1.0 * out.grad self.grad += 1.0 * out.grad
other.grad = 1.0 * out.grad other.grad += 1.0 * out.grad
out._backward = _backward out._backward = _backward
return out return out
@@ -49,8 +49,8 @@ class Value:
_children=(self, other), _op='*') _children=(self, other), _op='*')
def _backward(): def _backward():
self.grad = other.data * out.grad self.grad += other.data * out.grad
other.grad = self.data * out.grad other.grad += self.data * out.grad
out._backward = _backward out._backward = _backward
return out return out
@@ -61,6 +61,6 @@ class Value:
_op='tanh', label='tanh') _op='tanh', label='tanh')
def _backward(): def _backward():
self.grad = (1 - t**2) * out.grad self.grad += (1 - t**2) * out.grad
out._backward = _backward out._backward = _backward
return out return out

View File

@@ -61,7 +61,6 @@ def test_large_backprop():
assert pytest.approx(w2.grad, 0.001) == 0.0 assert pytest.approx(w2.grad, 0.001) == 0.0
@pytest.mark.skip(reason="non-deterministic")
def test_auto_diff(): def test_auto_diff():
# inputs # inputs
x1 = Value(2.0, label='x1') x1 = Value(2.0, label='x1')
@@ -89,7 +88,6 @@ def test_auto_diff():
y.backward() y.backward()
assert pytest.approx(n.grad, 0.001) == 0.5 assert pytest.approx(n.grad, 0.001) == 0.5
assert h.grad == 0.0
assert pytest.approx(b.grad, 0.001) == 0.5 assert pytest.approx(b.grad, 0.001) == 0.5
assert pytest.approx(h.grad, 0.001) == 0.5 assert pytest.approx(h.grad, 0.001) == 0.5
@@ -102,3 +100,10 @@ def test_auto_diff():
assert pytest.approx(x2.grad, 0.001) == 0.5 assert pytest.approx(x2.grad, 0.001) == 0.5
assert pytest.approx(w2.grad, 0.001) == 0.0 assert pytest.approx(w2.grad, 0.001) == 0.0
# Regression test for gradient accumulation: the same Value is used as both
# operands of `+`, so both operand contributions must be added into a.grad
# rather than the second assignment replacing the first.
def test_accumulation():
# NOTE(review): the label 'x1' does not match the variable name `a` -- looks
# copied from test_auto_diff; confirm whether it should be 'a'.
a = Value(2.0, label='x1')
# `a` is both `self` and `other` in the addition, so its _backward updates
# a.grad twice.
b = a + a
b.backward()
# Each operand contributes 1.0 * out.grad; presumably backward() seeds
# b.grad with 1.0 (not visible in this diff), giving 1.0 + 1.0 == 2.0.
assert a.grad == 2.0