diff --git a/src/main/scala/collaborativeFiltering/MatrixFactorization.scala b/src/main/scala/collaborativeFiltering/MatrixFactorization.scala index 4cbdeda..abb7552 100644 --- a/src/main/scala/collaborativeFiltering/MatrixFactorization.scala +++ b/src/main/scala/collaborativeFiltering/MatrixFactorization.scala @@ -113,9 +113,9 @@ class MatrixFactorization extends Serializable{ val testTimeStart = System.currentTimeMillis() val bc_test_itemFactors = ratingsByRow.context.broadcast(itemFactors) //training loss - val loss = ratingsByRow.mapPartitions {iter => + val loss = ratingsByRow.mapPartitionsWithIndex {(index,iter) => val localV = bc_test_itemFactors.value - val localU = MatrixFactorization.workerstore.get[Map[Int, Vector]]("userFactors") + val localU = MatrixFactorization.workerstore.get[Map[Int, Vector]](s"userFactors_$index") val reguV = localV.mapValues(v => lambda_v * v.dot(v)) val reguU = localU.mapValues(u => lambda_u * u.dot(u)) val ls = iter.foldLeft(0.0) { (l, r) => @@ -125,7 +125,7 @@ class MatrixFactorization extends Serializable{ l + residual * residual + reguU.get(r.index_x).get + reguV.get(r.index_y).get } Iterator.single(ls) - }.reduce(_ + _) / numRatings + }.sum() / numRatings bc_test_itemFactors.unpersist() print(s"$loss\t") testTime += (System.currentTimeMillis() - testTimeStart) @@ -160,7 +160,6 @@ class MatrixFactorization extends Serializable{ val vj = localV.get(ranRating.index_y).get //update vj val residual = ranRating.rating - uh.dot(vj) - val rrr = stepsize * residual vj *= (1 - stepsize * lambda_v) vj.plusax(stepsize * residual, uh) loss += (residual * residual) @@ -178,15 +177,15 @@ class MatrixFactorization extends Serializable{ itemFactors.foreach(ui => ui._2 /= numParts.toDouble) bc_itemFactors.unpersist() - val approxLoss = lossSum / (numParts * numInnerIters) - if (i != 0) { - val oldLoss = lossList.last - if (approxLoss > oldLoss) - stepsize = stepsize * 0.5 - else - stepsize *= 1.05 - } - lossList.append(approxLoss) +// val approxLoss = lossSum / (numParts * numInnerIters) +// if (i != 0) { +// val oldLoss = lossList.last +// if (approxLoss > oldLoss) +// stepsize = stepsize * 0.5 +// else +// stepsize *= 1.05 +// } +// lossList.append(approxLoss) // println(s"approximate loss: $approxLoss, time: ${System.currentTimeMillis() - startTime}")