Skip to main content

敏感词过滤算法


敏感词过滤算法

敏感词过滤,说白了就是字符串匹配,看输入的字符串是否存在敏感词。

常见的比较好的算法有KMP,Trie树(字典树),AC自动机(KMP+Trie)。

整体结构

我们定义SensitiveWordFilter接口,然后有不同算法的实现策略

接口包含三个核心功能:过滤敏感词(filter)、判断是否命中敏感词(hasSensitiveWord),以及加载敏感词列表并构建敏感词树(loadWord)。

image-20240529202947037

敏感词过滤配置类

@Configuration
public class SensitiveWordConfig {

    // Injected and handed to the bootstrap through sensitiveWord(...) below;
    // presumably the provider of the word list (its type is declared elsewhere).
    @Autowired
    private MyWordFactory myWordFactory;

    /**
     * Creates the sensitive-word bootstrap bean.
     *
     * <p>The filtering algorithm is chosen by the argument passed to
     * {@code filterStrategy(...)}; pass a different filter implementation
     * there to switch strategies.
     *
     * @return the initialised bootstrap
     * @since 1.0.0
     */
    @Bean
    public SensitiveWordBs sensitiveWordBs() {
        return SensitiveWordBs.newInstance()
                .filterStrategy(DFAFilter.getInstance())
                .sensitiveWord(myWordFactory)
                .init();
    }
}

filterStrategy(DFAFilter.getInstance())可以替换不同的策略

使用:

// Inject the bootstrap bean declared in SensitiveWordConfig.
@Resource
private SensitiveWordBs sensitiveWordBs;
// Pass the message content through the configured filter.
sensitiveWordBs.filter(body.getContent())

下面是算法的实现代码:

AC自动机

ACTrie

/**
 * Aho-Corasick automaton (AC automaton) built over a list of words.
 *
 * <p>The trie is built from the words, then every node receives a failover
 * pointer to the node spelling the longest proper suffix of its own prefix
 * that is also present in the trie. {@link #matches(String)} walks the text
 * once and reports every occurrence of every word, including occurrences
 * that overlap or are nested in one another.
 *
 * <p>Instances are not thread-safe.
 */
@NotThreadSafe
public class ACTrie {

    // Root node: the empty prefix. Its failover pointer stays null.
    private ACTrieNode root;

    /**
     * Builds the automaton.
     *
     * @param words words to recognise; duplicates, {@code null} entries and
     *              empty strings are ignored, and a {@code null} list yields
     *              an automaton that matches nothing
     */
    public ACTrie(List<String> words) {
        root = new ACTrieNode();
        if (words != null) {
            for (String word : words.stream().distinct().collect(Collectors.toList())) {
                insert(word);
            }
        }
        // Failover pointers are computed once, after all words are in place.
        initFailover();
    }

    /**
     * Adds one word and recomputes the failover pointers so the automaton
     * stays usable by {@link #matches(String)}.
     *
     * @param word word to add; {@code null} and empty strings are ignored
     */
    public void addWord(String word) {
        if (insert(word)) {
            initFailover();
        }
    }

    /**
     * Inserts the characters of {@code word} into the trie without touching
     * failover pointers.
     *
     * @return {@code true} if the word was inserted, {@code false} if it was
     *         {@code null} or empty
     */
    private boolean insert(String word) {
        if (word == null || word.isEmpty()) {
            return false;
        }
        ACTrieNode walkNode = root;
        for (int i = 0; i < word.length(); i++) {
            char c = word.charAt(i);
            walkNode.addChildrenIfAbsent(c);
            walkNode = walkNode.childOf(c);
            // Depth equals the length of the prefix this node represents.
            walkNode.setDepth(i + 1);
        }
        walkNode.setLeaf(true);
        return true;
    }

    /**
     * (Re)computes the failover pointer of every node with a breadth-first
     * traversal, so that a node's failover target is always finished before
     * the node itself is processed.
     */
    private void initFailover() {
        Queue<ACTrieNode> queue = new LinkedList<>();
        // First level: the only proper suffix is the empty one, i.e. root.
        for (ACTrieNode node : root.getChildren().values()) {
            node.setFailover(root);
            queue.offer(node);
        }
        while (!queue.isEmpty()) {
            ACTrieNode parentNode = queue.poll();
            for (Map.Entry<Character, ACTrieNode> entry : parentNode.getChildren().entrySet()) {
                ACTrieNode childNode = entry.getValue();
                // Climb the parent's failover chain until a node that can be
                // extended by the child's character is found.
                ACTrieNode failover = parentNode.getFailover();
                while (failover != null && !failover.hasChild(entry.getKey())) {
                    failover = failover.getFailover();
                }
                if (failover == null) {
                    // Ran past the root: no non-empty suffix exists in the trie.
                    childNode.setFailover(root);
                } else {
                    childNode.setFailover(failover.childOf(entry.getKey()));
                }
                queue.offer(childNode);
            }
        }
    }

    /**
     * Finds every occurrence of every word in {@code text}.
     *
     * <p>Results are produced in order of increasing end position; for
     * occurrences sharing an end position the longest comes first. Result
     * ranges may overlap.
     *
     * @param text text to scan; {@code null} yields an empty list
     * @return one {@code MatchResult(start, end)} per occurrence, where
     *         {@code start} is inclusive and {@code end} is exclusive
     */
    public List<MatchResult> matches(String text) {
        List<MatchResult> result = Lists.newArrayList();
        if (text == null) {
            return result;
        }
        ACTrieNode walkNode = root;
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            // On mismatch fall back to shorter and shorter suffixes.
            while (!walkNode.hasChild(c) && walkNode.getFailover() != null) {
                walkNode = walkNode.getFailover();
            }
            if (!walkNode.hasChild(c)) {
                // Still at the root with no transition for this character.
                continue;
            }
            walkNode = walkNode.childOf(c);
            // The current state stays untouched so that longer words sharing
            // this prefix can still complete. Every word ending at i is a
            // suffix of the current prefix and therefore sits on the failover
            // chain, so the chain is scanned for leaves. This costs at most
            // the current node's depth per character.
            for (ACTrieNode hit = walkNode; hit != null && hit != root; hit = hit.getFailover()) {
                if (hit.isLeaf()) {
                    result.add(new MatchResult(i - hit.getDepth() + 1, i + 1));
                }
            }
        }
        return result;
    }

}

ACTrieNode

@Getter
@Setter
public class ACTrieNode {

    // Outgoing edges, keyed by the next character.
    private Map<Character, ACTrieNode> children = Maps.newHashMap();

    // Node the matcher falls back to when the next character has no edge here.
    private ACTrieNode failover = null;

    // Number of characters on the path from the root to this node.
    private int depth;

    // True when a complete word ends at this node.
    private boolean isLeaf = false;

    /**
     * Creates an empty child for {@code c} unless one is already present.
     *
     * @param c edge character
     */
    public void addChildrenIfAbsent(char c) {
        if (children.get(c) == null) {
            children.put(c, new ACTrieNode());
        }
    }

    /**
     * @param c edge character
     * @return the child reached through {@code c}, or {@code null} if there is none
     */
    public ACTrieNode childOf(char c) {
        ACTrieNode child = children.get(c);
        return child;
    }

    /**
     * @param c edge character
     * @return whether an edge labelled {@code c} leaves this node
     */
    public boolean hasChild(char c) {
        boolean present = children.containsKey(c);
        return present;
    }

    /** Renders this node followed by its whole failover chain; children are omitted. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("ACTrieNode{");
        text.append("failover=").append(failover);
        text.append(", depth=").append(depth);
        text.append(", isLeaf=").append(isLeaf);
        text.append('}');
        return text.toString();
    }
}

ACFilter

/**
 * Sensitive-word filter backed by an Aho-Corasick automaton ({@code ACTrie}).
 *
 * <p>It implements the same {@code SensitiveWordFilter} contract as the other
 * filters (such as {@code DFAFilter}) and can be used in their place.
 *
 * <p>The automaton is held in a static field, so the word list loaded through
 * any instance is shared by all instances.
 */
public class ACFilter implements SensitiveWordFilter {

    // Character written over every character of a matched word.
    private static final char MASK_CHAR = '*';

    // Shared automaton; null until loadWord has been called with a non-null list.
    private static ACTrie acTrie = null;

    /**
     * Tells whether the text contains at least one loaded word.
     *
     * @param text text to inspect
     * @return {@code true} if a word occurs in {@code text}; {@code false} for
     *         blank text or when no word list has been loaded
     */
    public boolean hasSensitiveWord(String text) {
        if (StringUtils.isBlank(text) || acTrie == null) {
            return false;
        }
        // Ask the automaton directly rather than comparing masked and
        // unmasked strings.
        return !acTrie.matches(text).isEmpty();
    }

    /**
     * Replaces every character of every matched word with {@code '*'}.
     *
     * @param text text to filter
     * @return the masked text; {@code text} itself when it is blank or when no
     *         word list has been loaded
     */
    public String filter(String text) {
        if (StringUtils.isBlank(text) || acTrie == null) {
            return text;
        }
        List<MatchResult> matchResults = acTrie.matches(text);
        if (matchResults.isEmpty()) {
            return text;
        }
        StringBuilder result = new StringBuilder(text);
        // Each match is masked over exactly its own range, so overlapping
        // matches are handled whatever order they are reported in.
        for (MatchResult matchResult : matchResults) {
            replaceBetween(result, matchResult.getStartIndex(), matchResult.getEndIndex());
        }
        return result.toString();
    }

    /** Overwrites {@code buffer[startIndex, endIndex)} with the mask character. */
    private static void replaceBetween(StringBuilder buffer, int startIndex, int endIndex) {
        for (int i = startIndex; i < endIndex; i++) {
            buffer.setCharAt(i, MASK_CHAR);
        }
    }

    /**
     * Replaces the shared automaton with one built from {@code words}.
     *
     * @param words word list; {@code null} leaves the current automaton untouched
     */
    public void loadWord(List<String> words) {
        if (words == null) {
            return;
        }
        acTrie = new ACTrie(words);
    }

}

AC-PRO

ACProTrie

/**
 * Aho-Corasick automaton that masks sensitive words directly.
 *
 * <p>{@link #createACTrie(List)} builds the trie and its fail links;
 * {@link #match(String)} then scans a text once and overwrites every
 * character that belongs to an occurrence of a loaded word with {@code '*'}.
 * Nested words (abc / abcd), overlapping words (ab / bcd) and words that are
 * only a suffix of the current trie path (bc inside abcX when abcd is loaded)
 * are all masked.
 */
public class ACProTrie {

    private final static char MASK = '*'; // replacement character

    // Root state (empty prefix); null until createACTrie has been called.
    private Word root;

    // One automaton state.
    static class Word {
        // True when a loaded word ends exactly at this state.
        boolean end = false;
        // State to fall back to when the next character has no transition.
        Word failOver = null;
        // Length of the prefix this state represents.
        int depth = 0;
        // Length of the longest loaded word that is a suffix of this state's
        // prefix, or 0 if there is none. Filled in by initFailOver.
        int matchDepth = 0;
        // Transitions to the next states.
        Map<Character, Word> next = new HashMap<>();

        public boolean hasChild(char c) {
            return next.containsKey(c);
        }
    }

    /**
     * Builds the automaton from scratch.
     *
     * @param list sensitive words; a {@code null} list, {@code null} entries
     *             and empty strings contribute nothing
     */
    public void createACTrie(List<String> list) {
        root = new Word();
        if (list != null) {
            for (String key : list) {
                if (key == null) {
                    continue;
                }
                Word currentNode = root;
                for (int j = 0; j < key.length(); j++) {
                    char c = key.charAt(j);
                    Word child = currentNode.next.get(c);
                    if (child == null) {
                        child = new Word();
                        child.depth = j + 1;
                        currentNode.next.put(c, child);
                    }
                    currentNode = child;
                }
                // Mark the end after the walk so the flag is right whatever
                // order the words arrive in (da before or after dadac). The
                // root is never marked, which keeps empty strings harmless.
                if (currentNode != root) {
                    currentNode.end = true;
                }
            }
        }
        initFailOver();
    }

    /**
     * Computes {@code failOver} and {@code matchDepth} for every state with a
     * breadth-first traversal. Does nothing if no trie has been created yet.
     */
    public void initFailOver() {
        if (root == null) {
            return;
        }
        Queue<Word> queue = new LinkedList<>();
        for (Word node : root.next.values()) {
            node.failOver = root;
            node.matchDepth = node.end ? node.depth : 0;
            queue.offer(node);
        }
        while (!queue.isEmpty()) {
            Word parentNode = queue.poll();
            for (Map.Entry<Character, Word> entry : parentNode.next.entrySet()) {
                Word childNode = entry.getValue();
                Word failOver = parentNode.failOver;
                while (failOver != null && !failOver.next.containsKey(entry.getKey())) {
                    failOver = failOver.failOver;
                }
                childNode.failOver = (failOver == null) ? root : failOver.next.get(entry.getKey());
                // The fail target is shallower than the child, so the
                // breadth-first order has already finalised its matchDepth.
                // A word ending here is longer than anything on the fail chain.
                childNode.matchDepth = childNode.end ? childNode.depth : childNode.failOver.matchDepth;
                queue.offer(childNode);
            }
        }
    }

    /**
     * Masks every occurrence of every loaded word.
     *
     * @param matchWord text to filter
     * @return a copy of the text with matched characters replaced by
     *         {@code '*'}; the argument itself when it is {@code null} or
     *         when no trie has been created
     */
    public String match(String matchWord) {
        if (matchWord == null || root == null) {
            return matchWord;
        }
        // Characters are read from the untouched input and written to a copy,
        // so masking never interferes with later matching.
        char[] masked = matchWord.toCharArray();
        Word walkNode = root;
        for (int i = 0; i < masked.length; i++) {
            char c = matchWord.charAt(i);
            // Fall back on mismatch.
            while (!walkNode.hasChild(c) && walkNode.failOver != null) {
                walkNode = walkNode.failOver;
            }
            Word nextNode = walkNode.next.get(c);
            if (nextNode == null) {
                // At the root with no transition for this character.
                continue;
            }
            // The state is never reset after a hit, so a longer word sharing
            // this prefix is still found as the scan continues.
            walkNode = nextNode;
            // Masking the longest word that ends at i also covers every
            // shorter word ending at i.
            for (int j = i - walkNode.matchDepth + 1; j <= i; j++) {
                masked[j] = MASK;
            }
        }
        return new String(masked);
    }
}

ACProFilter

/**
 * Sensitive-word filter backed by {@code ACProTrie}, the enhanced variant of
 * {@code ACFilter}.
 *
 * <p>The trie is an instance field: each filter instance holds the word list
 * most recently passed to its own {@link #loadWord(List)}.
 */
public class ACProFilter implements SensitiveWordFilter {

    // Null until loadWord has been called with a non-null list.
    private ACProTrie acProTrie;

    /**
     * Tells whether filtering would change the text.
     *
     * @param text text to inspect
     * @return {@code true} if {@link #filter(String)} returns something other
     *         than {@code text}; {@code false} for blank text
     */
    @Override
    public boolean hasSensitiveWord(String text) {
        if (StringUtils.isBlank(text)) {
            return false;
        }
        return !Objects.equals(filter(text), text);
    }

    /**
     * Masks the sensitive words in the text.
     *
     * @param text text to filter
     * @return the masked text; {@code text} itself when it is blank or when no
     *         word list has been loaded yet
     */
    @Override
    public String filter(String text) {
        // Same guard as ACFilter.filter, plus protection against use before
        // loadWord; both previously ended in a NullPointerException.
        if (StringUtils.isBlank(text) || acProTrie == null) {
            return text;
        }
        return acProTrie.match(text);
    }

    /**
     * Builds a fresh trie from the given words and makes it the active one.
     *
     * @param words word list; {@code null} leaves the current trie untouched
     */
    @Override
    public void loadWord(List<String> words) {
        if (words == null) {
            return;
        }
        // Build completely before publishing to the field, so filter never
        // sees a half-built trie through this reference.
        ACProTrie trie = new ACProTrie();
        trie.createACTrie(words);
        acProTrie = trie;
    }
}

DFA

/**
 * Sensitive-word filter based on a character trie (DFA).
 *
 * <p>All state is static: the dictionary is shared by everything that obtains
 * this filter. ASCII upper-case letters are folded to lower case and the
 * characters in the skip set are ignored, both while loading words and while
 * scanning text.
 *
 * <p>Scanning is greedy and non-overlapping: from each start position the
 * longest word is masked and scanning resumes right after it, so a word that
 * begins inside an already masked span is not detected.
 */
public final class DFAFilter implements SensitiveWordFilter {

    // All state is static, so a single shared instance is sufficient.
    private static final DFAFilter INSTANCE = new DFAFilter();

    // Dictionary root. Volatile so that a dictionary swapped in by loadWord is
    // safely published to threads that are calling filter.
    private static volatile Word root = new Word(' ');
    private static final char REPLACE = '*'; // replacement character
    // Characters that are ignored wherever they appear.
    private static final String SKIP_CHARS = " !*-+_=,,.@;:;:。、??()()【】[]《》<>“”\"‘’";
    private static final Set<Character> SKIP_SET = new HashSet<>();

    static {
        for (char c : SKIP_CHARS.toCharArray()) {
            SKIP_SET.add(c);
        }
    }

    private DFAFilter() {
    }

    /**
     * @return the shared filter instance
     */
    public static DFAFilter getInstance() {
        return INSTANCE;
    }


    /**
     * Tells whether the text contains a sensitive word.
     *
     * @param text text to inspect
     * @return {@code true} if filtering would change the text; {@code false}
     *         otherwise, including for blank text
     */
    public boolean hasSensitiveWord(String text) {
        if (StringUtils.isBlank(text)) {
            return false;
        }
        return !Objects.equals(filter(text), text);
    }

    /**
     * Replaces sensitive words with {@code '*'}.
     *
     * <p>Skipped characters lying between the first and last character of a
     * matched word are masked together with the word.
     *
     * @param text text to filter
     * @return the masked text; {@code text} itself when it is {@code null} or empty
     */
    public String filter(String text) {
        if (text == null || text.isEmpty()) {
            return text;
        }
        // One snapshot per call, so a concurrent reload cannot switch
        // dictionaries in the middle of a scan.
        final Word dictionary = root;
        StringBuilder result = new StringBuilder(text);
        int index = 0;
        while (index < result.length()) {
            if (skip(result.charAt(index))) {
                index++;
                continue;
            }
            Word word = dictionary;
            // Index of the last character of the longest word starting at
            // index, or -1 if no word starts here.
            int matchEnd = -1;
            for (int i = index; i < result.length(); i++) {
                char c = result.charAt(i);
                if (skip(c)) {
                    continue;
                }
                word = word.next.get(toLowerAscii(c));
                if (word == null) {
                    break;
                }
                if (word.end) {
                    // Keep walking: a longer word may share this prefix.
                    matchEnd = i;
                }
            }
            if (matchEnd < 0) {
                index++;
            } else {
                for (int j = index; j <= matchEnd; j++) {
                    result.setCharAt(j, REPLACE);
                }
                // Resume after the masked span; this does not rely on the
                // mask character being in the skip set.
                index = matchEnd + 1;
            }
        }
        return result.toString();
    }


    /**
     * Replaces the dictionary with one built from {@code words}.
     *
     * @param words word list; a {@code null} or empty list leaves the current
     *              dictionary untouched
     */
    public void loadWord(List<String> words) {
        if (!CollectionUtils.isEmpty(words)) {
            // Build aside, then publish with a single volatile write.
            Word newRoot = new Word(' ');
            words.forEach(word -> loadWord(word, newRoot));
            root = newRoot;
        }
    }

    /**
     * Inserts one word into the dictionary rooted at {@code root}.
     *
     * @param word word to add; blank words, and words consisting only of
     *             skipped characters, add nothing
     * @param root root of the dictionary being built
     */
    public void loadWord(String word, Word root) {
        if (StringUtils.isBlank(word)) {
            return;
        }
        Word current = root;
        for (int i = 0; i < word.length(); i++) {
            char c = toLowerAscii(word.charAt(i));
            if (skip(c)) {
                continue;
            }
            Word next = current.next.get(c);
            if (next == null) {
                next = new Word(c);
                current.next.put(c, next);
            }
            current = next;
        }
        // If every character was skipped, current is still the root, which
        // must not be marked as the end of a word.
        if (current != root) {
            current.end = true;
        }
    }


    /**
     * Loads the word list from a text file, one word per line.
     *
     * <p>Best effort: an I/O failure is printed and the current dictionary is
     * kept.
     *
     * @param path path of the text file
     */
    public void loadWordFromFile(String path) {
        try (InputStream inputStream = Files.newInputStream(Paths.get(path))) {
            loadWord(inputStream);
        } catch (IOException e) {
            // NOTE(review): no logger is in scope here; route this through the
            // project's logging facility when one is available.
            e.printStackTrace();
        }
    }

    /**
     * Loads the word list from a UTF-8 stream, one word per line. The stream
     * is closed when reading finishes.
     *
     * @param inputStream stream over the word file
     * @throws IOException if reading fails
     */
    public void loadWord(InputStream inputStream) throws IOException {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            List<String> list = new ArrayList<>();
            String line;
            while ((line = reader.readLine()) != null) {
                list.add(line);
            }
            loadWord(list);
        }
    }

    /**
     * @param c character to test
     * @return {@code true} if the character is in the skip set
     */
    private boolean skip(char c) {
        return SKIP_SET.contains(c);
    }

    /** Folds ASCII {@code A-Z} to {@code a-z}; every other character is returned unchanged. */
    private static char toLowerAscii(char c) {
        return (c >= 'A' && c <= 'Z') ? (char) (c + 32) : c;
    }

    /**
     * One node of the dictionary trie.
     */
    private static class Word {
        // Character on the edge leading to this node.
        private final char c;

        // True when a word ends at this node.
        private boolean end;

        // Child nodes, keyed by the next character.
        private Map<Character, Word> next;

        public Word(char c) {
            this.c = c;
            this.next = new HashMap<>();
        }
    }
}