forked from pool/xgboost
upgrade OBS-URL: https://build.opensuse.org/request/show/1117176 OBS-URL: https://build.opensuse.org/package/show/Java:packages/xgboost?expand=0&rev=18
483 lines
18 KiB
Diff
483 lines
18 KiB
Diff
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
|
|
index 8d4f2c05..7649df65 100644
|
|
--- a/jvm-packages/xgboost4j/pom.xml
|
|
+++ b/jvm-packages/xgboost4j/pom.xml
|
|
@@ -29,18 +29,6 @@
|
|
<artifactId>scala-collection-compat_${scala.binary.version}</artifactId>
|
|
<version>${scala-collection-compat.version}</version>
|
|
</dependency>
|
|
- <dependency>
|
|
- <groupId>org.apache.hadoop</groupId>
|
|
- <artifactId>hadoop-hdfs</artifactId>
|
|
- <version>${hadoop.version}</version>
|
|
- <scope>provided</scope>
|
|
- </dependency>
|
|
- <dependency>
|
|
- <groupId>org.apache.hadoop</groupId>
|
|
- <artifactId>hadoop-common</artifactId>
|
|
- <version>${hadoop.version}</version>
|
|
- <scope>provided</scope>
|
|
- </dependency>
|
|
<dependency>
|
|
<groupId>junit</groupId>
|
|
<artifactId>junit</artifactId>
|
|
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
|
|
deleted file mode 100644
|
|
index 655b9902..00000000
|
|
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
|
|
+++ /dev/null
|
|
@@ -1,117 +0,0 @@
|
|
-package ml.dmlc.xgboost4j.java;
|
|
-
|
|
-import java.io.IOException;
|
|
-import java.io.InputStream;
|
|
-import java.io.OutputStream;
|
|
-import java.util.*;
|
|
-import java.util.stream.Collectors;
|
|
-
|
|
-import org.apache.commons.logging.Log;
|
|
-import org.apache.commons.logging.LogFactory;
|
|
-import org.apache.hadoop.fs.FileSystem;
|
|
-import org.apache.hadoop.fs.Path;
|
|
-
|
|
-public class ExternalCheckpointManager {
|
|
-
|
|
- private Log logger = LogFactory.getLog("ExternalCheckpointManager");
|
|
- private String modelSuffix = ".model";
|
|
- private Path checkpointPath;
|
|
- private FileSystem fs;
|
|
-
|
|
- public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
|
|
- if (checkpointPath == null || checkpointPath.isEmpty()) {
|
|
- throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
|
|
- " empty checkpoint path");
|
|
- }
|
|
- this.checkpointPath = new Path(checkpointPath);
|
|
- this.fs = fs;
|
|
- }
|
|
-
|
|
- private String getPath(int version) {
|
|
- return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
|
|
- }
|
|
-
|
|
- private List<Integer> getExistingVersions() throws IOException {
|
|
- if (!fs.exists(checkpointPath)) {
|
|
- return new ArrayList<>();
|
|
- } else {
|
|
- return Arrays.stream(fs.listStatus(checkpointPath))
|
|
- .map(path -> path.getPath().getName())
|
|
- .filter(fileName -> fileName.endsWith(modelSuffix))
|
|
- .map(fileName -> Integer.valueOf(
|
|
- fileName.substring(0, fileName.length() - modelSuffix.length())))
|
|
- .collect(Collectors.toList());
|
|
- }
|
|
- }
|
|
-
|
|
- public void cleanPath() throws IOException {
|
|
- fs.delete(checkpointPath, true);
|
|
- }
|
|
-
|
|
- public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
|
|
- List<Integer> versions = getExistingVersions();
|
|
- if (versions.size() > 0) {
|
|
- int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
|
|
- String checkpointPath = getPath(latestVersion);
|
|
- InputStream in = fs.open(new Path(checkpointPath));
|
|
- logger.info("loaded checkpoint from " + checkpointPath);
|
|
- Booster booster = XGBoost.loadModel(in);
|
|
- booster.setVersion(latestVersion);
|
|
- return booster;
|
|
- } else {
|
|
- return null;
|
|
- }
|
|
- }
|
|
-
|
|
- public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
|
|
- List<String> prevModelPaths = getExistingVersions().stream()
|
|
- .map(this::getPath).collect(Collectors.toList());
|
|
- String eventualPath = getPath(boosterToCheckpoint.getVersion());
|
|
- String tempPath = eventualPath + "-" + UUID.randomUUID();
|
|
- try (OutputStream out = fs.create(new Path(tempPath), true)) {
|
|
- boosterToCheckpoint.saveModel(out);
|
|
- fs.rename(new Path(tempPath), new Path(eventualPath));
|
|
- logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
|
|
- prevModelPaths.stream().forEach(path -> {
|
|
- try {
|
|
- fs.delete(new Path(path), true);
|
|
- } catch (IOException e) {
|
|
- logger.error("failed to delete outdated checkpoint at " + path, e);
|
|
- }
|
|
- });
|
|
- }
|
|
- }
|
|
-
|
|
- public void cleanUpHigherVersions(int currentRound) throws IOException {
|
|
- getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
|
|
- try {
|
|
- fs.delete(new Path(getPath(v)), true);
|
|
- } catch (IOException e) {
|
|
- logger.error("failed to clean checkpoint from other training instance", e);
|
|
- }
|
|
- });
|
|
- }
|
|
-
|
|
- public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
|
|
- throws IOException {
|
|
- if (checkpointInterval > 0) {
|
|
- List<Integer> prevRounds =
|
|
- getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
|
|
- prevRounds.add(0);
|
|
- int firstCheckpointRound = prevRounds.stream()
|
|
- .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
|
|
- List<Integer> arr = new ArrayList<>();
|
|
- for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
|
|
- arr.add(i);
|
|
- }
|
|
- arr.add(numOfRounds);
|
|
- return arr;
|
|
- } else if (checkpointInterval <= 0) {
|
|
- List<Integer> l = new ArrayList<Integer>();
|
|
- l.add(numOfRounds);
|
|
- return l;
|
|
- } else {
|
|
- throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
|
|
- }
|
|
- }
|
|
-}
|
|
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
|
|
index bcd0b1b1..3e23c15f 100644
|
|
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
|
|
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
|
|
@@ -22,7 +22,6 @@ import java.util.regex.Pattern;
|
|
|
|
import org.apache.commons.logging.Log;
|
|
import org.apache.commons.logging.LogFactory;
|
|
-import org.apache.hadoop.fs.FileSystem;
|
|
|
|
/**
|
|
* trainer for xgboost
|
|
@@ -134,34 +133,35 @@ public class XGBoost {
|
|
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
|
|
}
|
|
|
|
- private static void saveCheckpoint(
|
|
- Booster booster,
|
|
- int iter,
|
|
- Set<Integer> checkpointIterations,
|
|
- ExternalCheckpointManager ecm) throws XGBoostError {
|
|
- try {
|
|
- if (checkpointIterations.contains(iter)) {
|
|
- ecm.updateCheckpoint(booster);
|
|
- }
|
|
- } catch (Exception e) {
|
|
- logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
|
|
- throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
|
|
- }
|
|
- }
|
|
-
|
|
- public static Booster trainAndSaveCheckpoint(
|
|
+ /**
|
|
+ * Train a booster given parameters.
|
|
+ *
|
|
+ * @param dtrain Data to be trained.
|
|
+ * @param params Parameters.
|
|
+ * @param round Number of boosting iterations.
|
|
+ * @param watches a group of items to be evaluated during training, this allows user to watch
|
|
+ * performance on the validation set.
|
|
+ * @param metrics array containing the evaluation metrics for each matrix in watches for each
|
|
+ * iteration
|
|
+ * @param earlyStoppingRounds if non-zero, training would be stopped
|
|
+ * after a specified number of consecutive
|
|
+ * goes to the unexpected direction in any evaluation metric.
|
|
+ * @param obj customized objective
|
|
+ * @param eval customized evaluation
|
|
+ * @param booster train from scratch if set to null; train from an existing booster if not null.
|
|
+ * @return The trained booster.
|
|
+ */
|
|
+ public static Booster train(
|
|
DMatrix dtrain,
|
|
Map<String, Object> params,
|
|
- int numRounds,
|
|
+ int round,
|
|
Map<String, DMatrix> watches,
|
|
float[][] metrics,
|
|
IObjective obj,
|
|
IEvaluation eval,
|
|
int earlyStoppingRounds,
|
|
- Booster booster,
|
|
- int checkpointInterval,
|
|
- String checkpointPath,
|
|
- FileSystem fs) throws XGBoostError, IOException {
|
|
+ Booster booster) throws XGBoostError {
|
|
+
|
|
//collect eval matrixs
|
|
String[] evalNames;
|
|
DMatrix[] evalMats;
|
|
@@ -169,11 +169,6 @@ public class XGBoost {
|
|
int bestIteration;
|
|
List<String> names = new ArrayList<String>();
|
|
List<DMatrix> mats = new ArrayList<DMatrix>();
|
|
- Set<Integer> checkpointIterations = new HashSet<>();
|
|
- ExternalCheckpointManager ecm = null;
|
|
- if (checkpointPath != null) {
|
|
- ecm = new ExternalCheckpointManager(checkpointPath, fs);
|
|
- }
|
|
|
|
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
|
|
names.add(evalEntry.getKey());
|
|
@@ -184,7 +179,7 @@ public class XGBoost {
|
|
evalMats = mats.toArray(new DMatrix[mats.size()]);
|
|
|
|
bestIteration = 0;
|
|
- metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
|
|
+ metrics = metrics == null ? new float[evalNames.length][round] : metrics;
|
|
|
|
//collect all data matrixs
|
|
DMatrix[] allMats;
|
|
@@ -209,22 +204,17 @@ public class XGBoost {
|
|
booster.setParams(params);
|
|
}
|
|
|
|
- if (ecm != null) {
|
|
- checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
|
|
- }
|
|
-
|
|
boolean initial_best_score_flag = false;
|
|
boolean max_direction = false;
|
|
|
|
// begin to train
|
|
- for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
|
|
+ for (int iter = booster.getVersion() / 2; iter < round; iter++) {
|
|
if (booster.getVersion() % 2 == 0) {
|
|
if (obj != null) {
|
|
booster.update(dtrain, obj);
|
|
} else {
|
|
booster.update(dtrain, iter);
|
|
}
|
|
- saveCheckpoint(booster, iter, checkpointIterations, ecm);
|
|
booster.saveRabitCheckpoint();
|
|
}
|
|
|
|
@@ -290,44 +280,6 @@ public class XGBoost {
|
|
return booster;
|
|
}
|
|
|
|
- /**
|
|
- * Train a booster given parameters.
|
|
- *
|
|
- * @param dtrain Data to be trained.
|
|
- * @param params Parameters.
|
|
- * @param round Number of boosting iterations.
|
|
- * @param watches a group of items to be evaluated during training, this allows user to watch
|
|
- * performance on the validation set.
|
|
- * @param metrics array containing the evaluation metrics for each matrix in watches for each
|
|
- * iteration
|
|
- * @param earlyStoppingRounds if non-zero, training would be stopped
|
|
- * after a specified number of consecutive
|
|
- * goes to the unexpected direction in any evaluation metric.
|
|
- * @param obj customized objective
|
|
- * @param eval customized evaluation
|
|
- * @param booster train from scratch if set to null; train from an existing booster if not null.
|
|
- * @return The trained booster.
|
|
- */
|
|
- public static Booster train(
|
|
- DMatrix dtrain,
|
|
- Map<String, Object> params,
|
|
- int round,
|
|
- Map<String, DMatrix> watches,
|
|
- float[][] metrics,
|
|
- IObjective obj,
|
|
- IEvaluation eval,
|
|
- int earlyStoppingRounds,
|
|
- Booster booster) throws XGBoostError {
|
|
- try {
|
|
- return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
|
|
- earlyStoppingRounds, booster,
|
|
- -1, null, null);
|
|
- } catch (IOException e) {
|
|
- logger.error("training failed in xgboost4j", e);
|
|
- throw new XGBoostError("training failed in xgboost4j ", e);
|
|
- }
|
|
- }
|
|
-
|
|
private static Integer tryGetIntFromObject(Object o) {
|
|
if (o instanceof Integer) {
|
|
return (int)o;
|
|
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
|
|
deleted file mode 100644
|
|
index 240c2387..00000000
|
|
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
|
|
+++ /dev/null
|
|
@@ -1,37 +0,0 @@
|
|
-/*
|
|
- Copyright (c) 2014 by Contributors
|
|
-
|
|
- Licensed under the Apache License, Version 2.0 (the "License");
|
|
- you may not use this file except in compliance with the License.
|
|
- You may obtain a copy of the License at
|
|
-
|
|
- http://www.apache.org/licenses/LICENSE-2.0
|
|
-
|
|
- Unless required by applicable law or agreed to in writing, software
|
|
- distributed under the License is distributed on an "AS IS" BASIS,
|
|
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
- See the License for the specific language governing permissions and
|
|
- limitations under the License.
|
|
- */
|
|
-
|
|
-package ml.dmlc.xgboost4j.scala
|
|
-
|
|
-import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
|
|
-import org.apache.hadoop.fs.FileSystem
|
|
-
|
|
-class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
|
|
- extends JavaECM(checkpointPath, fs) {
|
|
-
|
|
- def updateCheckpoint(booster: Booster): Unit = {
|
|
- super.updateCheckpoint(booster.booster)
|
|
- }
|
|
-
|
|
- def loadCheckpointAsScalaBooster(): Booster = {
|
|
- val loadedBooster = super.loadCheckpointAsBooster()
|
|
- if (loadedBooster == null) {
|
|
- null
|
|
- } else {
|
|
- new Booster(loadedBooster)
|
|
- }
|
|
- }
|
|
-}
|
|
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
|
|
index 50d86c89..49ce29d8 100644
|
|
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
|
|
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
|
|
@@ -20,61 +20,12 @@ import java.io.InputStream
|
|
import ml.dmlc.xgboost4j.java.{XGBoostError, XGBoost => JXGBoost}
|
|
|
|
import scala.jdk.CollectionConverters._
|
|
-import org.apache.hadoop.conf.Configuration
|
|
-import org.apache.hadoop.fs.Path
|
|
|
|
/**
|
|
* XGBoost Scala Training function.
|
|
*/
|
|
object XGBoost {
|
|
|
|
- private[scala] def trainAndSaveCheckpoint(
|
|
- dtrain: DMatrix,
|
|
- params: Map[String, Any],
|
|
- numRounds: Int,
|
|
- watches: Map[String, DMatrix] = Map(),
|
|
- metrics: Array[Array[Float]] = null,
|
|
- obj: ObjectiveTrait = null,
|
|
- eval: EvalTrait = null,
|
|
- earlyStoppingRound: Int = 0,
|
|
- prevBooster: Booster,
|
|
- checkpointParams: Option[ExternalCheckpointParams]): Booster = {
|
|
-
|
|
- // we have to filter null value for customized obj and eval
|
|
- val jParams: java.util.Map[String, AnyRef] =
|
|
- params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava
|
|
-
|
|
- val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
|
|
- val jBooster = if (prevBooster == null) {
|
|
- null
|
|
- } else {
|
|
- prevBooster.booster
|
|
- }
|
|
-
|
|
- val xgboostInJava = checkpointParams.
|
|
- map(cp => {
|
|
- JXGBoost.trainAndSaveCheckpoint(
|
|
- dtrain.jDMatrix,
|
|
- jParams,
|
|
- numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
|
|
- cp.checkpointInterval,
|
|
- cp.checkpointPath,
|
|
- new Path(cp.checkpointPath).getFileSystem(new Configuration()))
|
|
- }).
|
|
- getOrElse(
|
|
- JXGBoost.train(
|
|
- dtrain.jDMatrix,
|
|
- jParams,
|
|
- numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
|
- )
|
|
- if (prevBooster == null) {
|
|
- new Booster(xgboostInJava)
|
|
- } else {
|
|
- // Avoid creating a new SBooster with the same JBooster
|
|
- prevBooster
|
|
- }
|
|
- }
|
|
-
|
|
/**
|
|
* Train a booster given parameters.
|
|
*
|
|
@@ -104,8 +55,23 @@ object XGBoost {
|
|
eval: EvalTrait = null,
|
|
earlyStoppingRound: Int = 0,
|
|
booster: Booster = null): Booster = {
|
|
- trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
|
|
- booster, None)
|
|
+ val jWatches = watches.mapValues(_.jDMatrix).toMap.asJava
|
|
+ val jBooster = if (booster == null) {
|
|
+ null
|
|
+ } else {
|
|
+ booster.booster
|
|
+ }
|
|
+ val xgboostInJava = JXGBoost.train(
|
|
+ dtrain.jDMatrix,
|
|
+ // we have to filter null value for customized obj and eval
|
|
+ params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).toMap.asJava,
|
|
+ round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
|
|
+ if (booster == null) {
|
|
+ new Booster(xgboostInJava)
|
|
+ } else {
|
|
+ // Avoid creating a new SBooster with the same JBooster
|
|
+ booster
|
|
+ }
|
|
}
|
|
|
|
/**
|
|
@@ -160,41 +126,3 @@ object XGBoost {
|
|
new Booster(xgboostInJava)
|
|
}
|
|
}
|
|
-
|
|
-private[scala] case class ExternalCheckpointParams(
|
|
- checkpointInterval: Int,
|
|
- checkpointPath: String,
|
|
- skipCleanCheckpoint: Boolean)
|
|
-
|
|
-private[scala] object ExternalCheckpointParams {
|
|
-
|
|
- def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
|
|
- val checkpointPath: String = params.get("checkpoint_path") match {
|
|
- case None | Some(null) | Some("") => null
|
|
- case Some(path: String) => path
|
|
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
|
|
- s" an instance of String, but current value is ${params("checkpoint_path")}")
|
|
- }
|
|
-
|
|
- val checkpointInterval: Int = params.get("checkpoint_interval") match {
|
|
- case None => 0
|
|
- case Some(freq: Int) => freq
|
|
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
|
|
- " an instance of Int.")
|
|
- }
|
|
-
|
|
- val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
|
|
- case None => false
|
|
- case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
|
|
- case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
|
|
- " an instance of Boolean")
|
|
- }
|
|
- if (checkpointPath == null || checkpointInterval == 0) {
|
|
- None
|
|
- } else {
|
|
- Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
|
|
- }
|
|
- }
|
|
-}
|
|
-
|
|
-
|