一文详解高斯混合模型(GMM)在图像处理中的应用(附代码)

简介:

  一. 概述

高斯混合模型(GMM)在图像分割、对象识别、视频分析等方面均有应用,对于任意给定的数据样本集合,根据其分布概率, 可以计算每个样本数据向量的概率分布,从而根据概率分布对其进行分类,但是这些概率分布是混合在一起的,要从中分离出单个样本的概率分布就实现了样本数据聚类,而概率分布描述我们可以使用高斯函数实现,这个就是高斯混合模型-GMM。

一文详解高斯混合模型(GMM)在图像处理中的应用(附代码)

这种方法也称为D-EM即基于距离的期望最大化。

  二. 算法步骤

    1. 初始化变量定义-指定的聚类数目K与数据维度D

    2. 初始化均值、协方差、先验概率分布

    3. 迭代E-M步骤

         - E步计算期望

         - M步更新均值、协方差、先验概率分布

         -检测是否达到停止条件(最大迭代次数与最小误差满足),达到则退出迭代,否则继续E-M步骤

    4. 打印最终分类结果

  三. 代码实现


package com.gloomyfish.image.gmm;  
  
import java.util.ArrayList;  
import java.util.Arrays;  
import java.util.List;  
  
/** 
 *  
 * @author gloomy fish 
 * 
 */  
public class GMMProcessor {  
    public final static double MIN_VAR = 1E-10;  
    public static double[] samples = new double[]{10, 9, 4, 23, 13, 16, 5, 90, 100, 80, 55, 67, 8, 93, 47, 86, 3};  
    private int dimNum;  
    private int mixNum;  
    private double[] weights;  
    private double[][] m_means;  
    private double[][] m_vars;  
    private double[] m_minVars;  
  
    /*** 
     *  
     * @param m_dimNum - 每个样本数据的维度, 对于图像每个像素点来说是RGB三个向量 
     * @param m_mixNum - 需要分割为几个部分,即高斯混合模型中高斯模型的个数 
     */  
    public GMMProcessor(int m_dimNum, int m_mixNum) {  
        dimNum = m_dimNum;  
        mixNum = m_mixNum;  
        weights = new double[mixNum];  
        m_means = new double[mixNum][dimNum];  
        m_vars = new double[mixNum][dimNum];  
        m_minVars = new double[dimNum];  
    }  
      
    /*** 
     * data - 需要处理的数据 
     * @param data 
     */  
    public void process(double[] data) {  
        int m_maxIterNum = 100;  
        double err = 0.001;  
          
        boolean loop = true;  
        double iterNum = 0;  
        double lastL = 0;  
        double currL = 0;  
        int unchanged = 0;  
          
        initParameters(data);  
          
        int size = data.length;  
        double[] x = new double[dimNum];  
        double[][] next_means = new double[mixNum][dimNum];  
        double[] next_weights = new double[mixNum];  
        double[][] next_vars = new double[mixNum][dimNum];  
        List<DataNode> cList = new ArrayList<DataNode>();  
  
        while(loop) {  
            Arrays.fill(next_weights, 0);  
            cList.clear();  
            for(int i=0; i<mixNum; i++) {  
                Arrays.fill(next_means[i], 0);  
                Arrays.fill(next_vars[i], 0);  
            }  
              
            lastL = currL;  
            currL = 0;  
            for (int k = 0; k < size; k++)  
            {  
                for(int j=0;j<dimNum;j++)  
                    x[j]=data[k*dimNum+j];  
                double p = getProbability(x); // 总的概率密度分布  
                DataNode dn = new DataNode(x);  
                dn.index = k;  
                cList.add(dn);  
                double maxp = 0;  
                for (int j = 0; j < mixNum; j++)  
                {  
                    double pj = getProbability(x, j) * weights[j] / p; // 每个分类的概率密度分布百分比  
                    if(maxp < pj) {  
                        maxp = pj;  
                        dn.cindex = j;  
                    }  
      
                    next_weights[j] += pj; // 得到后验概率  
      
                    for (int d = 0; d < dimNum; d++)  
                    {  
                        next_means[j][d] += pj * x[d];  
                        next_vars[j][d] += pj* x[d] * x[d];  
                    }  
                }  
      
                currL += (p > 1E-20) ? Math.log10(p) : -20;  
            }  
            currL /= size;  
              
            // Re-estimation: generate new weight, means and variances.  
            for (int j = 0; j < mixNum; j++)  
            {  
                weights[j] = next_weights[j] / size;  
      
                if (weights[j] > 0)  
                {  
                    for (int d = 0; d < dimNum; d++)  
                    {  
                        m_means[j][d] = next_means[j][d] / next_weights[j];  
                        m_vars[j][d] = next_vars[j][d] / next_weights[j] - m_means[j][d] * m_means[j][d];  
                        if (m_vars[j][d] < m_minVars[d])  
                        {  
                            m_vars[j][d] = m_minVars[d];  
                        }  
                    }  
                }  
            }  
              
            // Terminal conditions  
            iterNum++;  
            if (Math.abs(currL - lastL) < err * Math.abs(lastL))  
            {  
                unchanged++;  
            }  
            if (iterNum >= m_maxIterNum || unchanged >= 3)  
            {  
                loop = false;  
            }  
        }  
          
        // print result  
        System.out.println("=================最终结果=================");  
        for(int i=0; i<mixNum; i++) {  
            for(int k=0; k<dimNum; k++) {  
                System.out.println("[" + i + "]: ");  
                System.out.println("means : " + m_means[i][k]);  
                System.out.println("var : " + m_vars[i][k]);  
                System.out.println();  
            }  
        }  
          
          
        // 获取分类  
        for(int i=0; i<size; i++) {  
            System.out.println("data[" + i + "]=" + data[i] + " cindex : " + cList.get(i).cindex);  
        }  
          
    }  
      
    /** 
     *  
     * @param data 
     */  
    private void initParameters(double[] data) {  
        // 随机方法初始化均值  
        int size = data.length;  
        for (int i = 0; i < mixNum; i++)  
        {  
            for (int d = 0; d < dimNum; d++)  
            {  
                m_means[i][d] = data[(int)(Math.random()*size)];  
            }  
        }  
          
        // 根据均值获取分类  
        int[] types = new int[size];  
        for (int k = 0; k < size; k++)  
        {  
            double max = 0;  
            for (int i = 0; i < mixNum; i++)  
            {  
                double v = 0;  
                for(int j=0;j<dimNum;j++) {  
                    v += Math.abs(data[k*dimNum+j] - m_means[i][j]);  
                }  
                if(v > max) {  
                    max = v;  
                    types[k] = i;  
                }  
            }  
        }  
        double[] counts = new double[mixNum];  
        for(int i=0; i<types.length; i++) {  
            counts[types[i]]++;  
        }  
          
        // 计算先验概率权重  
        for (int i = 0; i < mixNum; i++)  
        {  
            weights[i] = counts[i] / size;  
        }  
          
        // 计算每个分类的方差  
        int label = -1;  
        int[] Label = new int[size];  
        double[] overMeans = new double[dimNum];  
        double[] x = new double[dimNum];  
        for (int i = 0; i < size; i++)  
        {  
            for(int j=0;j<dimNum;j++)  
                x[j]=data[i*dimNum+j];  
            label=Label[i];  
  
            // Count each Gaussian  
            counts[label]++;  
            for (int d = 0; d < dimNum; d++)  
            {  
                m_vars[label][d] += (x[d] - m_means[types[i]][d]) * (x[d] - m_means[types[i]][d]);  
            }  
  
            // Count the overall mean and variance.  
            for (int d = 0; d < dimNum; d++)  
            {  
                overMeans[d] += x[d];  
                m_minVars[d] += x[d] * x[d];  
            }  
        }  
  
        // Compute the overall variance (* 0.01) as the minimum variance.  
        for (int d = 0; d < dimNum; d++)  
        {  
            overMeans[d] /= size;  
            m_minVars[d] = Math.max(MIN_VAR, 0.01 * (m_minVars[d] / size - overMeans[d] * overMeans[d]));  
        }  
  
        // Initialize each Gaussian.  
        for (int i = 0; i < mixNum; i++)  
        {  
  
            if (weights[i] > 0)  
            {  
                for (int d = 0; d < dimNum; d++)  
                {  
                    m_vars[i][d] = m_vars[i][d] / counts[i];  
  
                    // A minimum variance for each dimension is required.  
                    if (m_vars[i][d] < m_minVars[d])  
                    {  
                        m_vars[i][d] = m_minVars[d];  
                    }  
                }  
            }  
        }  
          
        System.out.println("=================初始化=================");  
        for(int i=0; i<mixNum; i++) {  
            for(int k=0; k<dimNum; k++) {  
                System.out.println("[" + i + "]: ");  
                System.out.println("means : " + m_means[i][k]);  
                System.out.println("var : " + m_vars[i][k]);  
                System.out.println();  
            }  
        }  
          
    }  
  
    /*** 
     *  
     * @param sample - 采样数据点 
     * @return 该点总概率密度分布可能性 
     */  
    public double getProbability(double[] sample)  
    {  
        double p = 0;  
        for (int i = 0; i < mixNum; i++)  
        {  
            p += weights[i] * getProbability(sample, i);  
        }  
        return p;  
    }  
  
    /** 
     * Gaussian Model -> PDF 
     * @param x - 表示采样数据点向量 
     * @param j - 表示对对应的第J个分类的概率密度分布 
     * @return - 返回概率密度分布可能性值 
     */  
    public double getProbability(double[] x, int j)  
    {  
        double p = 1;  
        for (int d = 0; d < dimNum; d++)  
        {  
            p *= 1 / Math.sqrt(2 * 3.14159 * m_vars[j][d]);  
            p *= Math.exp(-0.5 * (x[d] - m_means[j][d]) * (x[d] - m_means[j][d]) / m_vars[j][d]);  
        }  
        return p;  
    }  
      
    public static void main(String[] args) {  
        GMMProcessor filter = new GMMProcessor(1, 2);  
        filter.process(samples);  
          
    }  
}  


结构类DataNode


package com.gloomyfish.image.gmm;  
  
public class DataNode {  
    public int cindex; // cluster  
    public int index;  
    public double[] value;  
      
    public DataNode(double[] v) {  
        this.value = v;  
        cindex = -1;  
        index = -1;  
    }  
}  


  四. 结果

一文详解高斯混合模型(GMM)在图像处理中的应用(附代码)

这里初始中心均值的方法我是通过随机数来实现,GMM算法运行结果跟初始化有很大关系,常见初始化中心点的方法是通过K-Means来计算出中心点。大家可以尝试修改代码基于K-Means初始化参数,我之所以选择随机参数初始,主要是为了省事!




====================================分割线================================

本文作者:AI研习社
本文转自雷锋网禁止二次转载, 原文链接
目录
相关文章
|
3月前
|
机器学习/深度学习 自然语言处理 算法
【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战
【模式识别】探秘判别奥秘:Fisher线性判别算法的解密与实战
81 0
|
5月前
|
机器学习/深度学习 Python
LSTM神经网络实现对股市收盘价格的预测实战(python实现 附源码 超详细)
LSTM神经网络实现对股市收盘价格的预测实战(python实现 附源码 超详细)
68 0
|
6月前
|
机器学习/深度学习 存储 算法
机器学习面试笔试知识点之非监督学习-K 均值聚类、高斯混合模型(GMM)、自组织映射神经网络(SOM)
机器学习面试笔试知识点之非监督学习-K 均值聚类、高斯混合模型(GMM)、自组织映射神经网络(SOM)
48 0
|
3天前
|
存储 机器学习/深度学习 算法
R语言贝叶斯Metropolis-Hastings采样 MCMC算法理解和应用可视化案例
R语言贝叶斯Metropolis-Hastings采样 MCMC算法理解和应用可视化案例
|
9天前
|
数据可视化 数据挖掘 TensorFlow
Python贝叶斯高斯混合模型GMM聚类分析数据和混合密度可视化
Python贝叶斯高斯混合模型GMM聚类分析数据和混合密度可视化
|
3月前
|
API Python
gee python:利用核函数对影像进行平滑处理和边缘提取分析
gee python:利用核函数对影像进行平滑处理和边缘提取分析
22 0
|
3月前
|
机器学习/深度学习 算法 数据挖掘
【模式识别】探秘聚类奥秘:K-均值聚类算法解密与实战
【模式识别】探秘聚类奥秘:K-均值聚类算法解密与实战
51 1
|
4月前
Halcon 学习笔记九:mlp训练学习实战颜色识别
Halcon 学习笔记九:mlp训练学习实战颜色识别
51 0
|
传感器 机器学习/深度学习 人工智能
史上最全综述 | 3D目标检测算法汇总!(单目/双目/LiDAR/多模态/时序/半弱自监督)(下)
近年来,自动驾驶因其减轻驾驶员负担、提高行车安全的潜力而受到越来越多的关注。在现代自动驾驶系统中,感知系统是不可或缺的组成部分,旨在准确估计周围环境的状态,并为预测和规划提供可靠的观察结果。3D目标检测可以智能地预测自动驾驶车辆附近关键3D目标的位置、大小和类别,是感知系统的重要组成部分。本文回顾了应用于自动驾驶领域的3D目标检测的进展。
史上最全综述 | 3D目标检测算法汇总!(单目/双目/LiDAR/多模态/时序/半弱自监督)(下)
|
9月前
|
机器学习/深度学习 决策智能 计算机视觉
计算机视觉实战(八)直方图与傅里叶变换
计算机视觉实战(八)直方图与傅里叶变换