sparkmllib---sgd随机梯度下降算法(代码片段)

汪本成 汪本成     2022-12-06     549

关键词:

代码:

package mllib


import org.apache.log4j.Level, Logger
import org.apache.spark.SparkContext, SparkConf

import scala.collection.mutable.HashMap

/**
  * 随机梯度下降算法
  * Created by 汪本成 on 2016/8/5.
  */
object SGD 

  //屏蔽不必要的日志显示在终端上
  Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  //程序入口
  val conf = new SparkConf()
    .setMaster("local[1]")
    .setAppName(this.getClass().getSimpleName()
    .filter(!_.equals('$')))
  
  println(this.getClass().getSimpleName().filter(!_.equals('$')))

  val sc = new SparkContext(conf)

  //创建存储数据集HashMap集合
  val data = new HashMap[Int, Int]()
  //生成数据集内容
  def getData(): HashMap[Int, Int] = 
    for(i <- 1 to 50) 
      data += (i -> (2 * i))  //写入公式y=2x
    
    data
  

  //假设a=0
  var a: Double = 0
  //设置步进系数
  var b: Double = 0.1

  //设置迭代公式
  def sgd(x: Double, y: Double) = 
    a = a - b * ((a * x) - y)
  

  def main(args: Array[String]) 
    //获取数据集
    val dataSource = getData()
    println("data: ")
    dataSource.foreach(each => println(each + " "))
    println("\\nresult: ")
    var num = 1
    //开始迭代
    dataSource.foreach(myMap => 
      println(num + ":" + a + "("+myMap._1+","+myMap._2+")")
      sgd(myMap._1, myMap._2)
      num = num + 1
    )
    //显示结果
    println("最终结果a " + a)
  


运行结果:

"C:\\Program Files\\Java\\jdk1.8.0_77\\bin\\java" -Didea.launcher.port=7533 "-Didea.launcher.bin.path=D:\\Program Files (x86)\\JetBrains\\IntelliJ IDEA 15.0.5\\bin" -Dfile.encoding=UTF-8 -classpath "C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\charsets.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\deploy.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\access-bridge-64.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\cldrdata.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\dnsns.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\jaccess.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\jfxrt.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\localedata.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\nashorn.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunec.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunjce_provider.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunmscapi.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\sunpkcs11.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\ext\\zipfs.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\javaws.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jce.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jfr.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jfxswt.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\jsse.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\management-agent.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\plugin.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\resources.jar;C:\\Program Files\\Java\\jdk1.8.0_77\\jre\\lib\\rt.jar;G:\\location\\spark-mllib\\out\\production\\spark-mllib;C:\\Program Files (x86)\\scala\\lib\\scala-actors-migration.jar;C:\\Program Files (x86)\\scala\\lib\\scala-actors.jar;C:\\Program Files (x86)\\scala\\lib\\scala-library.jar;C:\\Program Files (x86)\\scala\\lib\\scala-reflect.jar;C:\\Program Files (x86)\\scala\\lib\\scala-swing.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-api-jdo-3.2.6.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-core-3.2.10.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\datanucleus-rdbms-3.2.9.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-1.6.1-yarn-shuffle.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-assembly-1.6.1-hadoop2.6.0.jar;G:\\home\\download\\spark-1.6.1-bin-hadoop2.6\\lib\\spark-examples-1.6.1-hadoop2.6.0.jar;D:\\Program Files (x86)\\JetBrains\\IntelliJ IDEA 15.0.5\\lib\\idea_rt.jar" com.intellij.rt.execution.application.AppMain mllib.SGD
SGD
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-assembly-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/G:/home/download/spark-1.6.1-bin-hadoop2.6/lib/spark-examples-1.6.1-hadoop2.6.0.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
16/08/05 00:48:28 INFO Slf4jLogger: Slf4jLogger started
16/08/05 00:48:28 INFO Remoting: Starting remoting
16/08/05 00:48:28 INFO Remoting: Remoting started; listening on addresses :[akka.tcp://sparkDriverActorSystem@192.168.43.1:24009]
data: 
(23,46) 
(50,100) 
(32,64) 
(41,82) 
(17,34) 
(8,16) 
(35,70) 
(44,88) 
(26,52) 
(11,22) 
(29,58) 
(38,76) 
(47,94) 
(20,40) 
(2,4) 
(5,10) 
(14,28) 
(46,92) 
(40,80) 
(49,98) 
(4,8) 
(13,26) 
(22,44) 
(31,62) 
(16,32) 
(7,14) 
(43,86) 
(25,50) 
(34,68) 
(10,20) 
(37,74) 
(1,2) 
(19,38) 
(28,56) 
(45,90) 
(27,54) 
(36,72) 
(18,36) 
(9,18) 
(21,42) 
(48,96) 
(3,6) 
(12,24) 
(30,60) 
(39,78) 
(15,30) 
(42,84) 
(24,48) 
(6,12) 
(33,66) 

result: 
1:0.0(23,46)
2:4.6000000000000005(50,100)
3:-8.400000000000002(32,64)
4:24.880000000000006(41,82)
5:-68.92800000000003(17,34)
6:51.649600000000035(8,16)
7:11.929920000000003(35,70)
8:-22.82480000000001(44,88)
9:86.40432000000006(26,52)
10:-133.04691200000013(11,22)
11:15.504691199999996(29,58)
12:-23.65891328(38,76)
13:73.84495718400001(47,94)
14:-263.82634158080003(20,40)
15:267.82634158080003(2,4)
16:214.66107326464004(5,10)
17:108.33053663232002(14,28)
18:-40.53221465292802(46,92)
19:155.1159727505409(40,80)
20:-457.3479182516227(49,98)
21:1793.4568811813288(4,8)
22:1076.8741287087973(13,26)
23:-320.46223861263934(22,44)
24:388.95468633516725(31,62)
25:-810.6048413038511(16,32)
26:489.56290478231085(7,14)
27:148.2688714346932(43,86)
28:-480.6872757344877(25,50)
29:726.0309136017315(34,68)
30:-1735.6741926441557(10,20)
31:2.0000000000002274(37,74)
32:1.999999999999386(1,2)
33:1.9999999999994476(19,38)
34:2.000000000000497(28,56)
35:1.9999999999991056(45,90)
36:2.00000000000313(27,54)
37:1.9999999999946787(36,72)
38:2.000000000013835(18,36)
39:1.999999999988932(9,18)
40:1.999999999998893(21,42)
41:2.0000000000012172(48,96)
42:1.9999999999953737(3,6)
43:1.9999999999967615(12,24)
44:2.000000000000648(30,60)
45:1.999999999998704(39,78)
46:2.0000000000037588(15,30)
47:1.9999999999981206(42,84)
48:2.0000000000060134(24,48)
49:1.999999999991581(6,12)
50:1.9999999999966325(33,66)
最终结果a为 2.0000000000077454
16/08/05 00:48:28 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.

Process finished with exit code 0
分析:
当α为0.1的时候,一般30次计算就计算出来了;如果是0.5,一般15次计算就有正确结果 。如果是1,则50次都没有结果

梯度下降随机梯度下降批量梯度下降

...点处沿着该方向变化最快,变化率最大(为该梯度的模)随机梯度下降(SGD):每次迭代随机使用一组样本针对BGD算法训练速度过慢的缺点,提出了SGD算法,普通的BGD算法是每次迭代把所有样本都过一遍,每训练一组样本就把梯... 查看详情

神经网络与深度学习:梯度下降算法和随机梯度下降算法

本文总结自《NeuralNetworksandDeepLearning》第1章的部分内容。  使用梯度下降算法进行学习(Learningwithgradientdescent)1.目标我们希望有一个算法,能让我们找到权重和偏置,以至于网络的输出y(x)能够拟合所有的训练输入x。2.代... 查看详情

机器学习梯度下降算法

...法的推导流程2梯度下降法大家族2.1全梯度下降算法(FG)2.2随机梯度下降算法(SG)2.3小批量梯度下降算法2.4随机平均梯度下降算法(SAG)3小结1详解梯度下降算法1.1梯度下降的相关概念复习在详细了解梯度下降的算法之前,我们先复... 查看详情

如何在tensorflow中实现多元线性随机梯度下降算法?

】如何在tensorflow中实现多元线性随机梯度下降算法?【英文标题】:Howtoimplementmultivariatelinearstochasticgradientdescentalgorithmintensorflow?【发布时间】:2016-07-0201:49:20【问题描述】:我从单变量线性梯度下降的简单实现开始,但不知道... 查看详情

机器学习算法(优化)之一:梯度下降算法随机梯度下降(应用于线性回归logistic回归等等)

本文介绍了机器学习中基本的优化算法—梯度下降算法和随机梯度下降算法,以及实际应用到线性回归、Logistic回归、矩阵分解推荐算法等ML中。梯度下降算法基本公式常见的符号说明和损失函数X :所有样本的特征向量组成的... 查看详情

梯度下降算法vs正规方程算法

...N个样本的梯度数据  优点:准确  缺点:速度慢②随机梯度下降:和批量梯度下降算法原理相似,区别在于求梯度时没有用所有的N歌样本数据,而是仅仅选取1个来求梯度  优点:速度快  缺点:准去率地③小批量梯度... 查看详情

梯度下降法和随机梯度下降法的区别

梯度下降和随机梯度下降之间的关键区别:  1、标准梯度下降是在权值更新前对所有样例汇总误差,而随机梯度下降的权值是通过考查某个训练样例来更新的。  2、在标准梯度下降中,权值更新的每一步对多个样例求和,... 查看详情

梯度下降法介绍-黑马程序员机器学习讲义

学习目标知道全梯度下降算法的原理知道随机梯度下降算法的原理知道随机平均梯度下降算法的原理知道小批量梯度下降算法的原理上一节中给大家介绍了最基本的梯度下降法实现流程,常见的梯度下降算法有:全梯度... 查看详情

机器学习梯度下降法(超详解)

...性回归-梯度下降法前言1.全梯度下降算法(FG)2.随机梯度下降算法(SG)3.小批量梯度下降算法(mini-batch)4.随机平均梯度下降算法(SAG)5.梯度下降法算法比较和进一步优化5.1算法比较5.2梯度下降... 查看详情

监督学习:随机梯度下降算法(sgd)和批梯度下降算法(bgd)

线性回归      首先要明白什么是回归。回归的目的是通过几个已知数据来预测另一个数值型数据的目标值。     假设特征和结果满足线性关系,即满足一个计算公式h(x),这个公式的自变... 查看详情

随机梯度下降实现 - MATLAB

】随机梯度下降实现-MATLAB【英文标题】:StochasticgradientDescentimplementation-MATLAB【发布时间】:2011-07-0407:00:27【问题描述】:我正在尝试在MATLAB中实现“Stochasticgradientdescent”。我完全遵循了算法,但我得到了一个非常非常大的w(... 查看详情

随机梯度下降算法求解svm

 测试代码(matlab)如下:clear;loadE:datasetUSPSUSPS.mat;%dataformat:%Xtrn1*dim%Xten2*dim%Ytrn1*1%Yten2*1%warning:labelsmustrangefrom1ton,nisthenumberoflabels%otherlabelvalueswillmakemistakesu=unique(Y 查看详情

随机梯度下降算法及其注释

#-*-coding:utf8-*-importmathimportmatplotlib.pyplotaspltdeff(w,x):N=len(w)i=0y=0whilei<N-1:y+=w[i]*x[i]i+=1y+=w[N-1]#常数项returnydefgradient(data,w,j):M=len(data)#样本数N=len(data[0])i=0g=0#当前维度的梯度while 查看详情

随机梯度下降分类器和回归器

随机梯度下降分类器并不是一个独立的算法,而是一系列利用随机梯度下降求解参数的算法的集合。SGDClassifier(分类):fromsklearn.linear_modelimportSGDClassifierclf=SGDClassifier(loss="hinge",penalty="l2") lossfunction(损失函数):可以通过loss... 查看详情

基于numpy实现随机梯度下降算法(代码片段)

...吧!本文较长,信息量较大,建议收藏!随机梯度下降在机器学习应用中被广泛使用。与反向传播相结合, 查看详情

牛客网算法八股刷题系列卷积函数随机梯度下降relu(代码片段)

牛客网算法八股刷题系列——卷积函数、随机梯度下降、ReLU题目描述正确答案:B\\mathcalBB题目解析题目描述本节并不过多针对题目中的非线性,而更多关注随机梯度下降、卷积运算自身以及卷积运算与全连接运算在动机... 查看详情

感知机2--随机梯度下降算法

声明:        1,本篇为个人对《2012.李航.统计学习方法.pdf》的学习总结。不得用作商用,欢迎转载,但请注明出处(即:本帖地址)。        2,因为本人在学习初始时... 查看详情

随机梯度下降法(stochasticgradientdescent,sgd)

...小)  Mold一直在更新 SGD(Stochasticgradientdescent)随机梯度下降法:每次迭代使用一组样本(样本量大)Mold把一批数据过完才更新一次 针对BGD算法训练速度过慢的缺点,提出了SGD算法,普通的BGD算法是每次迭代把所... 查看详情