Revert the changes for IgnoreUndefined management in tree evaluation #491

Merged
claudio.atzori merged 1 commit from fix_decision_tree into main 2024-10-11 10:33:42 +02:00
3 changed files with 12 additions and 31 deletions

View File

@ -48,7 +48,7 @@ public class TreeNodeDef implements Serializable {
// function for the evaluation of the node // function for the evaluation of the node
public TreeNodeStats evaluate(Row doc1, Row doc2, Config conf) { public TreeNodeStats evaluate(Row doc1, Row doc2, Config conf) {
TreeNodeStats stats = new TreeNodeStats(ignoreUndefined); TreeNodeStats stats = new TreeNodeStats();
// for each field in the node, it computes the // for each field in the node, it computes the
for (FieldConf fieldConf : fields) { for (FieldConf fieldConf : fields) {

View File

@ -9,11 +9,8 @@ public class TreeNodeStats implements Serializable {
private Map<String, FieldStats> results; // this is an accumulator for the results of the node private Map<String, FieldStats> results; // this is an accumulator for the results of the node
private final boolean ignoreUndefined; public TreeNodeStats() {
public TreeNodeStats(boolean ignoreUndefined) {
this.results = new HashMap<>(); this.results = new HashMap<>();
this.ignoreUndefined = ignoreUndefined;
} }
public Map<String, FieldStats> getResults() { public Map<String, FieldStats> getResults() {
@ -25,10 +22,7 @@ public class TreeNodeStats implements Serializable {
} }
public int fieldsCount() { public int fieldsCount() {
if (ignoreUndefined) return this.results.size();
return this.results.size();
else
return this.results.size() - undefinedCount(); // do not count undefined
} }
public int undefinedCount() { public int undefinedCount() {
@ -84,22 +78,11 @@ public class TreeNodeStats implements Serializable {
double min = 100.0; // random high value double min = 100.0; // random high value
for (FieldStats fs : this.results.values()) { for (FieldStats fs : this.results.values()) {
if (fs.getResult() < min) { if (fs.getResult() < min) {
if (fs.getResult() == -1) { if (fs.getResult() >= 0.0 || (fs.getResult() == -1 && fs.isCountIfUndefined()))
if (fs.isCountIfUndefined()) {
min = 0.0;
} else {
min = -1;
}
} else {
min = fs.getResult(); min = fs.getResult();
}
} }
} }
if (ignoreUndefined) { return min;
return min == -1.0 ? 0.0 : min;
} else {
return min;
}
} }
// if at least one is true, return 1.0 // if at least one is true, return 1.0
@ -108,11 +91,7 @@ public class TreeNodeStats implements Serializable {
if (fieldStats.getResult() >= fieldStats.getThreshold()) if (fieldStats.getResult() >= fieldStats.getThreshold())
return 1.0; return 1.0;
} }
if (!ignoreUndefined && undefinedCount() > 0) { return 0.0;
return -1.0;
} else {
return 0.0;
}
} }
// if at least one is false, return 0.0 // if at least one is false, return 0.0
@ -121,7 +100,7 @@ public class TreeNodeStats implements Serializable {
if (fieldStats.getResult() == -1) { if (fieldStats.getResult() == -1) {
if (fieldStats.isCountIfUndefined()) if (fieldStats.isCountIfUndefined())
return ignoreUndefined ? 0.0 : -1.0; return 0.0;
} else { } else {
if (fieldStats.getResult() < fieldStats.getThreshold()) if (fieldStats.getResult() < fieldStats.getThreshold())
return 0.0; return 0.0;

View File

@ -44,10 +44,12 @@ public class TreeProcessor {
TreeNodeStats stats = currentNode.evaluate(doc1, doc2, config); TreeNodeStats stats = currentNode.evaluate(doc1, doc2, config);
treeStats.addNodeStats(nextNodeName, stats); treeStats.addNodeStats(nextNodeName, stats);
double finalScore = stats.getFinalScore(currentNode.getAggregation()); // if ignoreUndefined=false the miss is considered as undefined
if (finalScore == -1.0) if (!currentNode.isIgnoreUndefined() && stats.undefinedCount() > 0) {
nextNodeName = currentNode.getUndefined(); nextNodeName = currentNode.getUndefined();
else if (finalScore >= currentNode.getThreshold()) { }
// if ignoreUndefined=true the miss is ignored and the score computed anyway
else if (stats.getFinalScore(currentNode.getAggregation()) >= currentNode.getThreshold()) {
nextNodeName = currentNode.getPositive(); nextNodeName = currentNode.getPositive();
} else { } else {
nextNodeName = currentNode.getNegative(); nextNodeName = currentNode.getNegative();