java 聚类_聚类算法 java实现
发布日期:2021-08-19 23:52:21 浏览次数:1 分类:技术文章

本文共 11563 字,大约阅读时间需要 38 分钟。

package org.test;

import java.util.ArrayList;

import org.algorithm.Kmeans;

public class KmeansTest {

public static void main(String[] args)

{

//初始化一个Kmean对象,将k置为10

Kmeans k=new Kmeans(10);

ArrayList dataSet=new ArrayList();

dataSet.add(new float[]{1,2});

dataSet.add(new float[]{3,3});

dataSet.add(new float[]{3,4});

dataSet.add(new float[]{5,6});

dataSet.add(new float[]{8,9});

dataSet.add(new float[]{4,5});

dataSet.add(new float[]{6,4});

dataSet.add(new float[]{3,9});

dataSet.add(new float[]{5,9});

dataSet.add(new float[]{4,2});

dataSet.add(new float[]{1,9});

dataSet.add(new float[]{7,8});

//设置原始数据集

k.setDataSet(dataSet);

//执行算法

k.execute();

//得到聚类结果

ArrayList> cluster=k.getCluster();

//查看结果

for(int i=0;i

{

k.printDataArray(cluster.get(i), "cluster["+i+"]");

}

}

}

package org.algorithm;

import java.util.ArrayList;

import java.util.Random;

public class Kmeans {

private int k;// 分成多少簇

private int m;// 迭代次数

private int dataSetLength;// 数据集元素个数,即数据集的长度

private ArrayList dataSet;// 数据集链表

private ArrayList center;// 中心链表

private ArrayList> cluster; // 簇

private ArrayList jc;// 误差平方和,k越接近dataSetLength,误差越小

private Random random;

public void setDataSet(ArrayList dataSet) {

this.dataSet = dataSet;

}

public ArrayList> getCluster() {

return cluster;

}

public Kmeans(int k) {

if (k <= 0) {

k = 1;

}

this.k = k;

}

private void init() {

m = 0;

random = new Random();

if (dataSet == null || dataSet.size() == 0) {

initDataSet();

}

dataSetLength = dataSet.size();

if (k > dataSetLength) {

k = dataSetLength;

}

center = initCenters();

cluster = initCluster();

jc = new ArrayList();

}

private void initDataSet() {

dataSet = new ArrayList();

// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0

float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, {

2, 5 },

{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },

{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 }

};

for (int i = 0; i < dataSetArray.length; i++) {

dataSet.add(dataSetArray[i]);

}

}

private ArrayList initCenters() {

ArrayList center = new ArrayList();

int[] randoms = new int[k];

boolean flag;

int temp = random.nextInt(dataSetLength);

randoms[0] = temp;

for (int i = 1; i < k; i++) {

flag = true;

while (flag) {

temp = random.nextInt(dataSetLength);

int j = 0;

// 不清楚for循环导致j无法加1

// for(j=0;j

// {

// if(temp==randoms[j]);

// {

// break;

// }

// }

while (j < i) {

if (temp == randoms[j]) {

break;

}

j++;

}

if (j == i) {

flag = false;

}

}

randoms[i] = temp;

}

// 测试随机数生成情况

// for(int i=0;i

// {

// System.out.println("test1:randoms["+i+"]="+randoms[i]);

// }

// System.out.println();

for (int i = 0; i < k; i++) {

center.add(dataSet.get(randoms[i]));// 生成初始化中心链表

}

return center;

}

private ArrayList> initCluster() {

ArrayList> cluster = new ArrayList>();

for (int i = 0; i < k; i++) {

cluster.add(new ArrayList());

}

return cluster;

}

private float distance(float[] element, float[] center) {

float distance = 0.0f;

float x = element[0] - center[0];

float y = element[1] - center[1];

float z = x * x + y * y;

distance = (float) Math.sqrt(z);

return distance;

}

private int minDistance(float[] distance) {

float minDistance = distance[0];

int minLocation = 0;

for (int i = 1; i < distance.length; i++) {

if (distance[i] < minDistance) {

minDistance = distance[i];

minLocation = i;

} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置

{

if (random.nextInt(10) < 5) {

minLocation = i;

}

}

}

return minLocation;

}

private void clusterSet() {

float[] distance = new float[k];

for (int i = 0; i < dataSetLength; i++) {

for (int j = 0; j < k; j++) {

distance[j] = distance(dataSet.get(i), center.get(j));

//

System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);

}

int minLocation = minDistance(distance);

//

System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);

// System.out.println();

cluster.get(minLocation).add(dataSet.get(i));//

核心,将当前元素放到最小距离中心相关的簇中

}

}

private float errorSquare(float[] element, float[] center) {

float x = element[0] - center[0];

float y = element[1] - center[1];

float errSquare = x * x + y * y;

return errSquare;

}

private void countRule() {

float jcF = 0;

for (int i = 0; i < cluster.size(); i++) {

for (int j = 0; j < cluster.get(i).size(); j++) {

jcF += errorSquare(cluster.get(i).get(j), center.get(i));

}

}

jc.add(jcF);

}

private void setNewCenter() {

for (int i = 0; i < k; i++) {

int n = cluster.get(i).size();

if (n != 0) {

float[] newCenter = { 0, 0 };

for (int j = 0; j < n; j++) {

newCenter[0] += cluster.get(i).get(j)[0];

newCenter[1] += cluster.get(i).get(j)[1];

}

// 设置一个平均值

newCenter[0] = newCenter[0] / n;

newCenter[1] = newCenter[1] / n;

center.set(i, newCenter);

}

}

}

public void printDataArray(ArrayList dataArray,

String dataArrayName) {

for (int i = 0; i < dataArray.size(); i++) {

System.out.println("print:" + dataArrayName + "[" + i +

"]={"

+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");

}

System.out.println("===================================");

}

private void kmeans() {

init();

// printDataArray(dataSet,"initDataSet");

// printDataArray(center,"initCenter");

// 循环分组,直到误差不变为止

while (true) {

clusterSet();

// for(int i=0;i

// {

// printDataArray(cluster.get(i),"cluster["+i+"]");

// }

countRule();

// System.out.println("count:"+"jc["+m+"]="+jc.get(m));

// System.out.println();

// 误差不变了,分组完成

if (m != 0) {

if (jc.get(m) - jc.get(m - 1) == 0) {

break;

}

}

setNewCenter();

// printDataArray(center,"newCenter");

m++;

cluster.clear();

cluster = initCluster();

}

// System.out.println("note:the times of

repeat:m="+m);//输出迭代次数

}

public void execute() {

long startTime = System.currentTimeMillis();

System.out.println("kmeans begins");

kmeans();

long endTime = System.currentTimeMillis();

System.out.println("kmeans running time=" + (endTime -

startTime)

+ "ms");

System.out.println("kmeans ends");

System.out.println();

}

}

package org.algorithm;

import java.util.ArrayList;

import java.util.Random;

public class Kmeans {

private int k;// 分成多少簇

private int m;// 迭代次数

private int dataSetLength;// 数据集元素个数,即数据集的长度

private ArrayList dataSet;// 数据集链表

private ArrayList center;// 中心链表

private ArrayList> cluster; // 簇

private ArrayList jc;// 误差平方和,k越接近dataSetLength,误差越小

private Random random;

public void setDataSet(ArrayList dataSet) {

this.dataSet = dataSet;

}

public ArrayList> getCluster() {

return cluster;

}

public Kmeans(int k) {

if (k <= 0) {

k = 1;

}

this.k = k;

}

private void init() {

m = 0;

random = new Random();

if (dataSet == null || dataSet.size() == 0) {

initDataSet();

}

dataSetLength = dataSet.size();

if (k > dataSetLength) {

k = dataSetLength;

}

center = initCenters();

cluster = initCluster();

jc = new ArrayList();

}

private void initDataSet() {

dataSet = new ArrayList();

// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0

float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5

},

{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },

{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };

for (int i = 0; i < dataSetArray.length; i++) {

dataSet.add(dataSetArray[i]);

}

}

private ArrayList initCenters() {

ArrayList center = new ArrayList();

int[] randoms = new int[k];

boolean flag;

int temp = random.nextInt(dataSetLength);

randoms[0] = temp;

for (int i = 1; i < k; i++) {

flag = true;

while (flag) {

temp = random.nextInt(dataSetLength);

int j = 0;

// 不清楚for循环导致j无法加1

// for(j=0;j

// {

// if(temp==randoms[j]);

// {

// break;

// }

// }

while (j < i) {

if (temp == randoms[j]) {

break;

}

j++;

}

if (j == i) {

flag = false;

}

}

randoms[i] = temp;

}

// 测试随机数生成情况

// for(int i=0;i

// {

// System.out.println("test1:randoms["+i+"]="+randoms[i]);

// }

// System.out.println();

for (int i = 0; i < k; i++) {

center.add(dataSet.get(randoms[i]));// 生成初始化中心链表

}

return center;

}

private ArrayList> initCluster() {

ArrayList> cluster = new ArrayList>();

for (int i = 0; i < k; i++) {

cluster.add(new ArrayList());

}

return cluster;

}

private float distance(float[] element, float[] center) {

float distance = 0.0f;

float x = element[0] - center[0];

float y = element[1] - center[1];

float z = x * x + y * y;

distance = (float) Math.sqrt(z);

return distance;

}

private int minDistance(float[] distance) {

float minDistance = distance[0];

int minLocation = 0;

for (int i = 1; i < distance.length; i++) {

if (distance[i] < minDistance) {

minDistance = distance[i];

minLocation = i;

} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置

{

if (random.nextInt(10) < 5) {

minLocation = i;

}

}

}

return minLocation;

}

private void clusterSet() {

float[] distance = new float[k];

for (int i = 0; i < dataSetLength; i++) {

for (int j = 0; j < k; j++) {

distance[j] = distance(dataSet.get(i), center.get(j));

//

System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]);

}

int minLocation = minDistance(distance);

//

System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation);

// System.out.println();

cluster.get(minLocation).add(dataSet.get(i));//

核心,将当前元素放到最小距离中心相关的簇中

}

}

private float errorSquare(float[] element, float[] center) {

float x = element[0] - center[0];

float y = element[1] - center[1];

float errSquare = x * x + y * y;

return errSquare;

}

private void countRule() {

float jcF = 0;

for (int i = 0; i < cluster.size(); i++) {

for (int j = 0; j < cluster.get(i).size(); j++) {

jcF += errorSquare(cluster.get(i).get(j), center.get(i));

}

}

jc.add(jcF);

}

private void setNewCenter() {

for (int i = 0; i < k; i++) {

int n = cluster.get(i).size();

if (n != 0) {

float[] newCenter = { 0, 0 };

for (int j = 0; j < n; j++) {

newCenter[0] += cluster.get(i).get(j)[0];

newCenter[1] += cluster.get(i).get(j)[1];

}

// 设置一个平均值

newCenter[0] = newCenter[0] / n;

newCenter[1] = newCenter[1] / n;

center.set(i, newCenter);

}

}

}

public void printDataArray(ArrayList dataArray,

String dataArrayName) {

for (int i = 0; i < dataArray.size(); i++) {

System.out.println("print:" + dataArrayName + "[" + i + "]={"

+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");

}

System.out.println("===================================");

}

private void kmeans() {

init();

// printDataArray(dataSet,"initDataSet");

// printDataArray(center,"initCenter");

// 循环分组,直到误差不变为止

while (true) {

clusterSet();

// for(int i=0;i

// {

// printDataArray(cluster.get(i),"cluster["+i+"]");

// }

countRule();

// System.out.println("count:"+"jc["+m+"]="+jc.get(m));

// System.out.println();

// 误差不变了,分组完成

if (m != 0) {

if (jc.get(m) - jc.get(m - 1) == 0) {

break;

}

}

setNewCenter();

// printDataArray(center,"newCenter");

m++;

cluster.clear();

cluster = initCluster();

}

// System.out.println("note:the times of repeat:m="+m);//输出迭代次数

}

public void execute() {

long startTime = System.currentTimeMillis();

System.out.println("kmeans begins");

kmeans();

long endTime = System.currentTimeMillis();

System.out.println("kmeans running time=" + (endTime -

startTime)

+ "ms");

System.out.println("kmeans ends");

System.out.println();

}

}

测试:

1.聚类分析

聚类分析是数据挖掘中的一种分析方法,它把一个没有类别标记的样本集按照一种准则划分成若干个相似的子集类,使相似的样本尽可能归为一类,不相似的划分到不同的类别中。聚类通过比较数据的相似性和差异性发现数据的内在特征和分布规律,从而获得对数据更深刻的认识。聚类分析以相似性为基础,在一个聚类中的模式之间比不在同一聚类中的模式之间具有更多的相似性。

聚类分析的算法可以分为划分法、层次法、基于密度的方法、基于网格的方法、基于模型的方法。

基于划分的聚类主要有k-平均及其变种,速度快易于实现还适用于文本图像等多种数据的聚类分析。

k-means算法

k-means 算法接受输入量 k ;然后将n个数据对象划分为

k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。

k-means 算法的工作过程说明如下:

首先从n个数据对象任意选择 k

个对象作为初始聚类中心;而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;

然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);不断重复这一过程直到标准测度函数开始收敛为止。

一般都采用均方差作为标准测度函数.

k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

KMeans算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值。

a4c26d1e5885305701be709a3d33442f.png

a4c26d1e5885305701be709a3d33442f.png

转载地址:https://blog.csdn.net/weixin_31414801/article/details/114035669 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:微信平台 java_java微信公众平台开发
下一篇:mysql optimizertrace_MySQL 调优 | OPTIMIZER_TRACE详解

发表评论

最新留言

第一次来,支持一个
[***.219.124.196]2024年03月18日 02时04分19秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

python中倒背如流_八字基础知识--倒背如流篇 2019-04-21
以太坊地址和公钥_以太坊地址是什么 2019-04-21
linux查看wifi信号命令_linux – 获取WIFI信号强度 – 寻求最佳方式(IOCTL,iwlist(iw)等)... 2019-04-21
npm 不重启 全局安装后_解决修复npm安装全局模块权限的问题 2019-04-21
vs格式化json 不生效_vs code 格式化 json 配置 2019-04-21
go 字符串反序列化成对象数组_Fastjson 1.2.24反序列化漏洞深度分析 2019-04-21
onmessage websocket 收不到信息_WebSocket断开重连解决方案,心跳重连实践 2019-04-21
hibernate mysql 缓存_hibernate和mysql的缓存问题,没辙了! 2019-04-21
abp框架 mysql_ABP框架使用Mysql数据库 2019-04-21
mysql树形递归删除_使用递归删除树形结构的所有子节点(java和mysql实现) 2019-04-21
linux mysql 不能连接远程_linux mysql 远程连接 2019-04-21
mysql $lt_mongodb中比较级查询条件:($lt $lte $gt $gte)(大于、小于)、查找条件... 2019-04-21
install python_Install python on AIX 7 2019-04-21
jquery查找div下第一个input_jquery查找div元素第一个元素id 2019-04-21
如何修改手机屏幕显示的长宽比例_屏幕分辨率 尺寸 比例 长宽 如何计算 2019-04-21
mysql 的版本 命名规则_MySQL版本和命名规则 2019-04-21
no java stack_Java Stack contains()用法及代码示例 2019-04-21
java动态代码_Java Agent入门学习之动态修改代码 2019-04-21
python集合如何去除重复数据_Python 迭代删除重复项,集合删除重复项 2019-04-21
iview 自定义时间选择器组件_Vue.js中使用iView日期选择器并设置开始时间结束时间校验功能... 2019-04-21