Training resets (destroy) cloned (or saved) Neural Nets #949

Open
aurium opened this issue Dec 30, 2024 · 0 comments · May be fixed by #951

aurium commented Dec 30, 2024

What is wrong?

I'm trying to save an LSTM net to localStorage so I can reuse it and continue training after reloading a web page.
However, after every reload, even a little further training resets the ANN.
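For reference, the save/restore round trip described above can be sketched as follows. It relies only on `toJSON()`/`fromJSON()` (the same calls the test code uses) and the Web Storage `setItem`/`getItem` methods. The helper names `saveNet`/`loadNet` and the `'lstm'` key are hypothetical, and the storage object is passed in so the sketch can also run outside a browser.

```javascript
// Hypothetical helper: serialize the net's JSON snapshot into a Web Storage-like object.
// `storage` is anything with getItem/setItem (e.g. window.localStorage in the browser).
function saveNet(net, storage, key) {
  storage.setItem(key, JSON.stringify(net.toJSON()))
}

// Hypothetical helper: rebuild a fresh net and load the saved snapshot into it.
// Returns null when nothing has been saved under `key` yet.
function loadNet(createNet, storage, key) {
  const raw = storage.getItem(key)
  if (raw === null) return null
  const net = createNet() // e.g. () => new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
  net.fromJSON(JSON.parse(raw))
  return net
}

// In the browser, after a reload (sketch):
// const make = () => new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
// const net = loadNet(make, localStorage, 'lstm') || make()
// ...train...
// saveNet(net, localStorage, 'lstm')
```

The restored net is created in exactly the same way as the clone in the test code, so it goes through the same `fromJSON` path that this issue reports as broken.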

Where does it happen?

In the Firefox web browser.

How do we replicate the issue?

  1. Create an LSTM net.
  2. Train it.
  3. Create a second LSTM net.
  4. Load the first net's data into the second.
  5. Observe that both produce the same output.
  6. Train the first net again: no problem, the error rate evolves normally.
  7. Train the second net: the error rate explodes.
    (Test code is included below.)

Expected behavior (i.e. solution)

The cloned (or saved and reloaded) net should continue evolving from the point where the original stopped.
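One way to make this expectation checkable is to compare the logged errors. The sketch below parses progress lines in the format shown in the example output (`iterations: N, training error: X`) and tests whether the first finite error logged while retraining a restored net stays within a tolerance of the original's last error. The function names and the default factor of 10 are arbitrary assumptions, not part of brain.js.

```javascript
// Parse progress lines such as "iterations: 10, training error: 0.0123"
// into { iterations, error } records; any other line is ignored.
function parseTrainingLog(text) {
  const re = /iterations: (\d+), training error: (\S+)/
  return text
    .split('\n')
    .map((line) => line.match(re))
    .filter(Boolean)
    .map((m) => ({ iterations: Number(m[1]), error: Number(m[2]) })) // "Infinity" parses to Infinity
}

// Hypothetical check: the first finite error after restoring should be at most
// `factor` times the last error the original net reached before it was saved.
function resumesWhereLeftOff(lastOriginalError, resumedLog, factor = 10) {
  const first = parseTrainingLog(resumedLog).find((r) => Number.isFinite(r.error))
  return first !== undefined && first.error <= lastOriginalError * factor
}
```

Applied to the example output in this issue, the original's retraining log passes this check against its last error (0.0124…), while the clone's log (first finite error 403136894.7…) fails it.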

Version information

Node.js: N/A

Browser: Firefox 128

Brain.js: https://unpkg.com/[email protected]

How important is this (1-5)?

5

Other Comments

Test code:

<script src="https://unpkg.com/brain.js"></script>
<script>
const net = new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
net.maxPredictionLength = 100

const trainData = [
  'doe, a deer, a female deer',
  'ray, a drop of golden sun',
  'me, a name I call myself',
]

// First train
net.train(trainData, {
  iterations: 5000,
  log: true,
  logPeriod: 500,
  learningRate: 0.2,
})

// Clone the net:
const net2 = new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
net2.fromJSON(net.toJSON())

// Both output the same text:
console.log('ray 1:', net.run('ray'))
console.log('ray 2:', net2.run('ray'))

// More training of the original; it resumes from its last error rate:
net.train(trainData, {
  iterations: 30,
  log: true,
  logPeriod: 10,
  learningRate: 0.2,
})

// More training to the clone:
net2.train(trainData, {
  iterations: 30,
  log: true,
  logPeriod: 10,
  learningRate: 0.2,
})
// (???) That started with a BIG error rate!

// The first reduced the quality, but the second is crazy:
console.log('ray 1:', net.run('ray'))
console.log('ray 2:', net2.run('ray'))
</script>

Example output:

iterations: 0, training error: Infinity
iterations: 500, training error: 0.01295498762785306
iterations: 1000, training error: 0.01216566726130864
iterations: 1500, training error: 0.012144481239444045
iterations: 2000, training error: 0.012128375937731972
iterations: 2500, training error: 186331.85451823566
iterations: 3000, training error: 0.013161829605093137
iterations: 3500, training error: 0.012466912897993345
iterations: 4000, training error: 0.013913231326531402
iterations: 4500, training error: 0.012443924643591718
ray 1: , a drop of golden sun
ray 2: , a drop of golden sun
iterations: 0, training error: 0.01232396260956039    <-- re-Training the original
iterations: 10, training error: 0.012362680081652325
iterations: 20, training error: 0.056502526340651955
iterations: 0, training error: Infinity               <-- Training the clone
iterations: 10, training error: 403136894.7168553
iterations: 20, training error: 469362.6213782614
ray 1: , a f g n of go go go goldnfseuff go go go goldnfseuff go go go golden
ray 2: eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
aurium added the bug label Dec 30, 2024
rizmyabdulla added a commit to rizmyabdulla/brain.js that referenced this issue Jan 18, 2025
Fixes BrainJS#949

Update `src/recurrent.ts` to ensure cloned LSTM nets continue training from the point where the original stopped.

* Add `fromJSON` method to properly restore the training state.
* Modify `train` method to account for the state of the cloned net.
* Update `trainPattern` method to consider the previous training state of the cloned net.
* Adjust `initialize` method to handle state restoration for cloned nets.
* Ensure `runInputs` method maintains continuity in training for cloned nets.

Add a test case in `src/recurrent/lstm.test.ts` to verify that training a cloned LSTM net continues evolving from the point where the original stopped.
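The commit message above says `fromJSON` should restore the training state. As a hedged illustration only, the toy below shows one general way a weights-only snapshot can break resumed training. It is plain JavaScript, nothing in it is taken from brain.js internals, and whether brain.js keeps comparable optimizer state is an assumption. An RMSProp-style optimizer keeps a per-weight running average of squared gradients. If a restore resets that cache to zero, the next step is scaled as though training had just started. The function names and the `lr`, `decay` and `eps` values are illustrative.

```javascript
// Toy RMSProp step on a single weight (NOT brain.js code; hyperparameters are illustrative).
function rmspropStep(state, grad, { lr = 0.01, decay = 0.999, eps = 1e-8 } = {}) {
  // Running mean of squared gradients: training state that lives beside the weight.
  const cache = decay * state.cache + (1 - decay) * grad * grad
  const w = state.w - (lr * grad) / Math.sqrt(cache + eps)
  return { w, cache }
}

// Minimize f(w) = w^2 (gradient 2w) for `steps` iterations starting from `state`.
function trainQuadratic(state, steps, opts) {
  for (let i = 0; i < steps; i++) state = rmspropStep(state, 2 * state.w, opts)
  return state
}

const before = trainQuadratic({ w: 1, cache: 0 }, 20)
const g = 2 * before.w
const resumedFull = rmspropStep(before, g) // snapshot kept both w and cache
const resumedWeightsOnly = rmspropStep({ w: before.w, cache: 0 }, g) // cache lost on restore
console.log('loss before resume:', before.w ** 2)
console.log('full state:', resumedFull.w ** 2, 'weights only:', resumedWeightsOnly.w ** 2)
```

With these settings, the full-state step keeps reducing the loss. The weights-only step overshoots and the loss jumps, loosely mirroring the jump in the clone's log.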
rizmyabdulla linked a pull request Jan 18, 2025 that will close this issue