diff --git a/README.md b/README.md index a4e813f..d148ab2 100644 --- a/README.md +++ b/README.md @@ -43,3 +43,33 @@ func main() { fmt.Println(sego.SegmentsToString(segments, false)) } ``` + + +```go +package main + +import ( + "fmt" + "github.com/huichen/sego" +) + +func init() { + // 注册分词器 + sego.RegisterDefaultSegmenter("github.com/huichen/sego/data/dictionary.txt") + sego.RegisterSegmenter("test", "github.com/huichen/sego/data/dictionary.txt") +} + +func main() { + // 载入词典 + segmenter := sego.GetDefaultSegmenter() + segmenter2 := sego.GetSegmenter("test") + + // 分词 + text := []byte("中华人民共和国中央人民政府") + segments := segmenter.Segment(text) + + // 处理分词结果 + // 支持普通模式和搜索模式两种分词,见代码中SegmentsToString函数的注释。 + fmt.Println(sego.SegmentsToString(segments, false)) +} +``` \ No newline at end of file diff --git a/segmenter.go b/segmenter.go index ff2f4a8..7c161d8 100644 --- a/segmenter.go +++ b/segmenter.go @@ -3,7 +3,8 @@ package sego import ( "bufio" - "fmt" + "errors" + "io" "log" "math" "os" @@ -17,6 +18,42 @@ const ( minTokenFrequency = 2 // 仅从字典文件中读取大于等于此频率的分词 ) +var DefaultSegmenter = new(Segmenter) + +var segmenters = map[string]*Segmenter{ + "default": DefaultSegmenter, +} + +// 注册分词器 +func RegisterSegmenter(alias string, files string) error { + seg := new(Segmenter) + err := seg.LoadDictionary(files) + if err != nil { + return err + } + + segmenters[alias] = seg + return nil +} + +// 获取分词器 +func GetSegmenter(alias string) *Segmenter { + if _, ok := segmenters[alias]; ok { + return segmenters[alias] + } + return nil +} + +// 注册默认分词器 +func RegisterDefaultSegmenter(files string) error { + return RegisterSegmenter("default", files) +} + +// 获取默认分词器 +func GetDefaultSegmenter() *Segmenter { + return GetSegmenter("default") +} + // 分词器结构体 type Segmenter struct { dict *Dictionary @@ -40,44 +77,46 @@ func (seg *Segmenter) Dictionary() *Dictionary { // 当一个分词既出现在用户词典也出现在通用词典中,则优先使用用户词典。 // // 词典的格式为(每个分词一行): -// 分词文本 频率 词性 -func (seg *Segmenter) LoadDictionary(files string) { +// 分词文本 频率 词性 (强制要求使用这个格式,不符合这个格式的行跳过,并且在log中可以看到) +func (seg *Segmenter) LoadDictionary(files string) error { seg.dict = new(Dictionary) for _, file := range strings.Split(files, ",") { log.Printf("载入sego词典 %s", file) dictFile, err := os.Open(file) defer dictFile.Close() if err != nil { - log.Fatalf("无法载入字典文件 \"%s\" \n", file) + log.Printf("无法载入字典文件 \"%s\" \n", file) + return err } reader := bufio.NewReader(dictFile) var text string - var freqText string var frequency int var pos string // 逐行读入分词 + line := 0 for { - size, _ := fmt.Fscanln(reader, &text, &freqText, &pos) - - if size == 0 { - // 文件结束 - break - } else if size < 2 { - // 无效行 + line++ + txt, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + // 文件结束 + break + } + log.Printf("%v 文件第 %v行读取错误,跳过: %v", file, line, err.Error()) continue - } else if size == 2 { - // 没有词性标注时设为空字符串 - pos = "" } - // 解析词频 - var err error - frequency, err = strconv.Atoi(freqText) - if err != nil { + parts := strings.Split(txt, " ") + if len(parts) < 3 { + log.Printf("%v 文件第 %v行读取错误,跳过: %v", file, line, "读取个数少于三个") continue } + N := len(parts) + text = strings.Join(parts[:N-2], " ") + frequency, err = strconv.Atoi(parts[N-2]) + pos = parts[N-1] // 过滤频率太小的词 if frequency < minTokenFrequency { @@ -123,6 +162,19 @@ func (seg *Segmenter) LoadDictionary(files string) { } log.Println("sego词典载入完毕") + return nil +} + +func parse(line string) (txt string, prequence int, pos string, err error) { + parts := strings.Split(line, " ") + if len(parts) < 3 { + return "", 0, "", errors.New("incomplete line") + } + N := len(parts) + txt = strings.Join(parts[:N-2], " ") + prequence, err = strconv.Atoi(parts[N-2]) + pos = parts[N-1] + return } // 对文本分词