-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlearn.go
270 lines (233 loc) · 6.03 KB
/
learn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
package bpe
import (
"strings"
"github.com/go-nlp/corpus"
)
// Tokenizer splits a piece of text into tokens. SimpleTokenizer is the implementation
// provided by this package.
type Tokenizer func(string) []string
// SimpleTokenizer strips leading and trailing spaces, carriage returns and newlines from a,
// then splits the remainder on single space characters. Consecutive spaces therefore
// produce empty tokens, and an empty (or all-whitespace) input yields one empty token.
func SimpleTokenizer(a string) []string {
	const edgeCutset = "\r\n "
	body := strings.Trim(a, edgeCutset)
	return strings.Split(body, " ")
}
// Statistics is the statistics of a corpus, used to figure out which pairs to replace.
type Statistics struct {
	// Stats maps each adjacent rune pair to its count. PairStats builds it by adding the
	// corpus frequency of the containing word for every occurrence of the pair.
	Stats map[Pair]int
	// Indices maps each pair to {word ID: number of times the pair occurs in that word}.
	Indices map[Pair]map[int]int
	// Corpus is the corpus the statistics describe. replacePair rewrites its words in
	// place (via ReplaceWord) as pairs are merged.
	Corpus *corpus.Corpus
	// MaxRune is the largest rune in use; the next merged symbol is MaxRune+1.
	MaxRune rune
}
// PairStats returns the occurrence frequencies of adjacent rune pairs in the corpus, where
// each occurrence is weighted by the corpus frequency of the word it appears in. It also
// builds an index from each pair to the IDs of the words containing it, along with how many
// times the pair occurs within that word, and records the largest rune seen.
//
// Of the options, only the end-of-word marker is consulted directly: when it is enabled,
// the last pair of every word is stored with its second rune negated, so a pair at the end
// of a word is counted separately from the same pair inside a word. All options are also
// passed through unchanged to Pairs.
func PairStats(c *corpus.Corpus, opts ...FuncOpt) Statistics {
	var mod funcMod
	for _, apply := range opts {
		apply(&mod)
	}
	mod.buf = nil

	// ENHANCEMENT: indices should have its own data struct
	result := Statistics{
		Stats:   make(map[Pair]int),
		Indices: make(map[Pair]map[int]int), // pair -> {word ID -> occurrences in that word}
		Corpus:  c,
	}
	for id := 0; id < c.Size(); id++ {
		word, _ := c.Word(id)
		weight := c.WordFreq(word)
		pairs := Pairs(word, opts...)
		lastIdx := len(pairs) - 1
		for j, pair := range pairs {
			// Track the largest rune (checked before any end-of-word negation);
			// replacePair allocates merged symbols starting just above it.
			if fst := pair.Fst(); fst > result.MaxRune {
				result.MaxRune = fst
			}
			if snd := pair.Snd(); snd > result.MaxRune {
				result.MaxRune = snd
			}
			if mod.markEOW && j == lastIdx {
				// the negative is a hack to mark the end of a word symbol
				pair.snd = rune(-(int32(pair.snd)))
			}
			result.Stats[pair] += weight
			perWord := result.Indices[pair]
			if perWord == nil {
				perWord = make(map[int]int)
				result.Indices[pair] = perWord
			}
			perWord[id]++
		}
	}
	return result
}
// Encoder represents a state that may be used to encode a word.
type Encoder struct {
	// Corpus is the corpus the encoder was learned from; Learn has already merged the
	// learned pairs into its words.
	Corpus *corpus.Corpus
	// Pairs lists the merged pairs in the order Learn selected them.
	Pairs []Pair
	// Replacements maps each merged pair to the rune that replaced it.
	Replacements map[Pair]rune
	// MaxRune is the largest rune in use after learning (the last allocated merge rune,
	// if any merges happened).
	MaxRune rune
}
// Learn learns an Encoder from the words in c by repeatedly merging the most frequent
// adjacent pair of symbols into a new rune, byte-pair-encoding style.
//
// At most symbols merges are performed. Learning stops early when no pair with a positive
// count remains, or when the most frequent pair's count is below minFreq. markEOW is
// forwarded to PairStats (via MarkEOW) to control end-of-word marking.
//
// Learn mutates c in place: each merge rewrites the affected words through
// c.ReplaceWord. The returned error is currently always nil.
func Learn(c *corpus.Corpus, symbols, minFreq int, markEOW bool) (retVal Encoder, err error) {
	// if there are any preallocated []Pair that is being used, they will be safe for reuse once this function finishes
	stats := PairStats(c, MarkEOW(markEOW))
	var list []Pair
	rep := make(map[Pair]rune)
	for i := 0; i < symbols; i++ {
		if len(stats.Stats) == 0 {
			// mode would return the zero Pair; there is nothing left to merge.
			break
		}
		m := mode(stats.Stats)
		freq := stats.Stats[m]
		// m has the highest count, so a non-positive count means no pair occurs any more.
		// Merging it would only record a dead pair and burn a rune, even when minFreq <= 0.
		if freq <= 0 || freq < minFreq {
			break // TODO error
		}
		replacements := replacePair(stats, m)
		updateStats(&stats, replacements, m)
		// updateStats has just advanced MaxRune to the rune replacePair substituted for m.
		rep[m] = stats.MaxRune
		list = append(list, m)
	}
	return Encoder{
		Corpus:       c,
		Pairs:        list,
		Replacements: rep,
		MaxRune:      stats.MaxRune,
	}, nil
}
// replacePair rewrites, in stats.Corpus, every word that stats.Indices lists as containing
// old (with a count of at least 1), replacing each occurrence of old with the single rune
// stats.MaxRune+1. It returns one replacedWord per rewritten word, holding the word's ID
// and its text before the rewrite, for updateStats to consume. The maps in stats are not
// modified here.
func replacePair(stats Statistics, old Pair) (retVal []replacedWord) {
	merged := stats.MaxRune + 1
	wordIDs := stats.Indices[old]
	retVal = make([]replacedWord, 0, len(wordIDs))
	for id, count := range wordIDs {
		if count >= 1 {
			before, _ := stats.Corpus.Word(id)
			stats.Corpus.ReplaceWord(id, replaceInString(before, old, merged))
			retVal = append(retVal, replacedWord{id, before})
		}
	}
	return retVal
}
// updateStats must be called immediately after replacePair.
//
// It updates stats to reflect that each occurrence of old in the words listed in
// replacements has been merged into the rune stats.MaxRune+1 (the same rune replacePair
// substituted). For each occurrence it:
//   - decrements the neighbouring pair(s) that shared a rune with old;
//   - increments the corresponding pair(s) built with the merged rune;
//   - mirrors the change in Indices via updateIndices.
//
// Finally it removes old from Stats and Indices and advances MaxRune to the merged rune.
//
// Each replacedWord carries the word's ID and its text *before* the merge; pairs are
// recomputed from that original text.
//
// NOTE(review): possible limitations, to be confirmed:
//   - Pairs(original) is called without the FuncOpts PairStats used. With end-of-word
//     marking enabled, the final pair here is not negated and may not match the key
//     PairStats stored.
//   - Counts are adjusted by 1 per occurrence, whereas PairStats adds the word's corpus
//     frequency per occurrence.
//   - Occurrences are handled in isolation. When they are adjacent or overlap (e.g.
//     merging ('a','b') in "abab", or ('a','a') in "aaa"), the neighbouring pair is itself
//     part of a merge. The adjustments then differ from what replaceInString produced, and
//     counts can go negative.
func updateStats(stats *Statistics, replacements []replacedWord, old Pair) {
	rr := stats.MaxRune + 1 // the merged rune; same value replacePair used
	for _, r := range replacements {
		original := r.original
		ps := Pairs(original)
		is := indicesOf(old, ps) // presumably the positions in ps where old occurs — TODO confirm
		for _, i := range is {
			switch i {
			case -1:
				// error
				// NOTE(review): presumably indicesOf's "not found" marker; nothing is adjusted.
			case 0:
				// old is the first pair: only the following pair has a neighbour to update.
				if len(ps) == 1 {
					// The word is exactly old; no neighbours. Moves on to the next index in is.
					continue
				}
				// replace next
				next := ps[i+1]
				p := P(rr, next.Snd())
				// update stats
				stats.Stats[next]--
				stats.Stats[p]++
				updateIndices(stats, next, p, r.id)
			case len(ps) - 1:
				// old is the last pair: only the preceding pair has a neighbour to update.
				// replace previous
				prev := ps[i-1]
				p := P(prev.Fst(), rr)
				//update stats
				stats.Stats[prev]--
				stats.Stats[p]++
				updateIndices(stats, prev, p, r.id)
			default:
				// old is in the middle: both neighbouring pairs change.
				// replace previous and next
				prev := ps[i-1]
				p := P(prev.Fst(), rr)
				// update stats
				stats.Stats[prev]--
				stats.Stats[p]++
				updateIndices(stats, prev, p, r.id)
				next := ps[i+1]
				q := P(rr, next.Snd())
				// update stats
				stats.Stats[next]--
				stats.Stats[q]++
				updateIndices(stats, next, q, r.id)
			}
		}
	}
	// old no longer exists anywhere in the rewritten words.
	delete(stats.Stats, old)
	delete(stats.Indices, old)
	stats.MaxRune++
}
// updateIndices moves one occurrence, within word wordID, from pair old to pair new in
// stats.Indices. The old pair's per-word count is decremented, and the word entry is
// dropped once that count is zero or below. The new pair's per-word count is incremented,
// allocating its inner map on first use.
func updateIndices(stats *Statistics, old, new Pair, wordID int) {
	oldWords := stats.Indices[old]
	if n, present := oldWords[wordID]; present {
		oldWords[wordID] = n - 1
	}
	if oldWords[wordID] <= 0 {
		// delete on a nil map is a no-op, so a missing old pair is fine.
		delete(oldWords, wordID)
	}

	newWords := stats.Indices[new]
	if newWords == nil {
		newWords = make(map[int]int)
		stats.Indices[new] = newWords
	}
	newWords[wordID]++
}
// UTIL

// mode returns the key of a with the highest count.
//
// Go's map iteration order is randomized, so ties are broken deterministically. Among
// keys sharing the highest count, the one with the smallest Fst(), then the smallest
// Snd(), wins. This choice is arbitrary but makes results reproducible (and testing
// easier).
//
// Any count value is handled correctly, including negatives; updateStats's ±1
// adjustments can push counts below zero. An empty map yields the zero Pair; callers
// should check for that case.
func mode(a map[Pair]int) Pair {
	var (
		best     Pair
		bestFreq int
		found    bool
	)
	for k, v := range a {
		switch {
		case !found || v > bestFreq:
			best, bestFreq, found = k, v, true
		case v == bestFreq && (k.Fst() < best.Fst() || (k.Fst() == best.Fst() && k.Snd() < best.Snd())):
			best = k
		}
	}
	return best
}
// replaceInString returns s with every occurrence of the rune pair p replaced by the
// single rune r. Matches are found left to right and do not overlap: replacing ('a', 'a')
// in "aaa" yields r followed by "a".
//
// The runes are compacted in place in a single pass, so the cost is linear in len(s).
// The write index never passes the read index, so unread input is never overwritten.
//
// NOTE(review): if p carries PairStats's end-of-word marker (a negated second rune) and
// Snd() returns it unchanged, it can never match a rune of s — confirm this is intended.
func replaceInString(s string, p Pair, r rune) string {
	fst, snd := p.Fst(), p.Snd()
	rs := []rune(s)
	w := 0
	for i := 0; i < len(rs); i++ {
		if rs[i] == fst && i+1 < len(rs) && rs[i+1] == snd {
			rs[w] = r
			i++ // the pair's second rune is consumed by the merge
		} else {
			rs[w] = rs[i]
		}
		w++
	}
	return string(rs[:w])
}
// tidyStats removes every pair whose count is zero or negative from stats.
func tidyStats(stats map[Pair]int) {
	// The Go spec permits deleting the current entry while ranging over a map.
	for pair, count := range stats {
		if count > 0 {
			continue
		}
		delete(stats, pair)
	}
}
// tidyIndices removes every pair whose word index is nil or empty from indices.
func tidyIndices(indices map[Pair]map[int]int) {
	// len of a nil map is 0, so one check covers both nil and empty inner maps.
	// Deleting the current entry during range is permitted by the Go spec.
	for pair, words := range indices {
		if len(words) != 0 {
			continue
		}
		delete(indices, pair)
	}
}