Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/main/scala/com/beepboop/app/MainApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import com.beepboop.app.logger.LogTrait
object MainApp extends LogTrait {

def main(args: Array[String]): Unit = {

val configOpt = ArgumentParser.parse(args)
if (configOpt.isEmpty) return
val config = configOpt.get

AppConfig.init(config)

info("--- Step 1: Configuration ---")
info(s"Model: ${config.modelPath}")
info(s"Max Iterations: ${config.maxIterations}")
Expand Down Expand Up @@ -125,9 +126,6 @@ object MainApp extends LogTrait {

info("--- Step 4. Starting constraint picking ---")

ConstraintSaver.setConfig(config)

ConstraintPicker.setConfig(config)
/* vvv comment if no gurobi license present vvv */
//ConstraintPicker.runInitial(result.get)

Expand Down
137 changes: 90 additions & 47 deletions src/main/scala/com/beepboop/app/MinizincModelListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.beepboop.app.dataprovider.*

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters.*
import scala.util.Try

class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean = true) extends NewMinizincParserBaseListener {

Expand All @@ -22,6 +23,23 @@ class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean
tokens.getText(new Interval(start, stop))
}

private def parseSimpleValue(text: String): Any = {
val t = text.trim
if (t == "true") return true
if (t == "false") return false

Try(t.toInt).toOption match {
case Some(i) => return i
case None =>
}

Try(t.toDouble).toOption match {
case Some(d) => return d
case None =>
}

None
}

override def enterVar_decl_item(ctx: Var_decl_itemContext): Unit = {
assert(ctx != null)
Expand All @@ -33,64 +51,89 @@ class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean
var detailedFullType = DataType(null, false, false)
var expr = ""
var isVar = false

var initialValue: Any = None

if (ctx.ti_expr_and_id().ti_expr().array_ti_expr() != null) {
if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().var_par().getText != "var")
{
// not var
fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr())
if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident()), true, true)
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), true, false)
}

// try to get expression for given decl Item, useful for set of int?
if (ctx.expr() != null) {
expr = ctx.expr().getText
}
} else {
fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr())
if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) {
if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null) {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident()), true, true)
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr()), true, true)
}
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), true, false)
}
isVar = true
}

} else if (ctx.ti_expr_and_id().ti_expr().base_ti_expr() != null) {
if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().var_par().getText != "var") {
fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr())
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false, Option(ctx.ti_expr_and_id().ti_expr().base_ti_expr().set_ti() != null).getOrElse(false))
val arrayTi = ctx.ti_expr_and_id().ti_expr().array_ti_expr()
val baseTi = arrayTi.base_ti_expr()

val isSetType = baseTi.set_ti() != null && baseTi.set_ti().getText.toLowerCase.contains("set")

if (baseTi.var_par().getText != "var") {
fullType = textWithSpaces(arrayTi)

val typeName = if (baseTi.base_ti_expr_tail().base_type() == null) {
textWithSpaces(baseTi.base_ti_expr_tail().ident())
} else {
fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr())
if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null){
if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null){
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident()), false, true)
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr()), false, true)
}
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false)
}
isVar = true
textWithSpaces(baseTi.base_ti_expr_tail().base_type())
}

val isIdent = baseTi.base_ti_expr_tail().base_type() == null

detailedFullType = DataType(typeName, true, isIdent, isSet = isSetType)

if (ctx.expr() != null) {
expr = ctx.expr().getText
}
} else {
fullType = textWithSpaces(arrayTi)
val tail = baseTi.base_ti_expr_tail()

if (tail.base_type() == null) {
if (tail.ident() != null) {
detailedFullType = DataType(textWithSpaces(tail.ident()), true, true, isSet = isSetType)
} else {
detailedFullType = DataType(textWithSpaces(arrayTi), true, true, isSet = isSetType)
}
} else {
detailedFullType = DataType(textWithSpaces(tail.base_type()), true, false, isSet = isSetType)
}
isVar = true
}
} else if (ctx.ti_expr_and_id().ti_expr().base_ti_expr() != null) {
val baseTi = ctx.ti_expr_and_id().ti_expr().base_ti_expr()
val isSetType = baseTi.set_ti() != null && baseTi.set_ti().getText.toLowerCase.contains("set")

if (baseTi.var_par().getText != "var") {
fullType = textWithSpaces(baseTi)
val typeName = if (baseTi.base_ti_expr_tail().base_type() != null) {
textWithSpaces(baseTi.base_ti_expr_tail().base_type())
} else {
baseTi.base_ti_expr_tail().getText
}
detailedFullType = DataType(
dataType = typeName,
isArray = false,
isIdentifier = false,
isSet = isSetType
)
} else {
fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr())
if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) {
if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null) {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident()), false, true)
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr()), false, true)
}
} else {
detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false)
}
isVar = true
}

if (ctx.expr() != null) {
expr = ctx.expr().getText
}
}

if (ctx.expr() != null) {
initialValue = parseSimpleValue(ctx.expr().getText)
}

val dataItem = if(extendedDataType){
DataItem(name = name, dataType = fullType, isVar = isVar, detailedDataType = detailedFullType, expr = expr)
val dataItem = if (extendedDataType) {
DataItem(name = name, dataType = fullType, isVar = isVar, value = initialValue, detailedDataType = detailedFullType, expr = expr)
} else {
DataItem(name = name, dataType = fullType, isVar = isVar, detailedDataType = null, expr = expr)
DataItem(name = name, dataType = fullType, isVar = isVar, value = initialValue, detailedDataType = null, expr = expr)
}
dataItemsBuffer += dataItem
}
Expand Down
13 changes: 9 additions & 4 deletions src/main/scala/com/beepboop/app/astar/AStar.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.beepboop.app.astar

import com.beepboop.app.*
import com.beepboop.app.components.{BinaryExpression, Expression}
import com.beepboop.app.components.{BinaryExpression, BoolType, Expression}
import com.beepboop.app.dataprovider.{DataItem, DataProvider}
import com.beepboop.app.logger.LogTrait
import com.beepboop.app.logger.Profiler
Expand All @@ -12,7 +12,7 @@ import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}

import scala.collection.parallel.CollectionConverters.RangeIsParallelizable
import com.beepboop.app.dataprovider.{AStarSnapshot, PersistenceManager}
import com.beepboop.app.policy.{Compliant, DenyDivByZero, EnsureAnyVarExists, NonCompliant, Scanner}
import com.beepboop.app.policy.{Compliant, DenyDivByZero, EnsureAnyVarExists, MaxDepth, NonCompliant, Scanner}
import com.beepboop.app.postprocessor.Postprocessor

import scala.collection.mutable
Expand All @@ -38,6 +38,10 @@ case class ModelNodeTMP(
) extends Serializable {
val f: Int = g + h

override def toString: String = {
s"[f=$f, g=$g, h=$h] ${constraint.toString}"
}

override def equals(obj: Any): Boolean = obj match {
case that: ModelNodeTMP => this.constraint == that.constraint
case _ => false
Expand Down Expand Up @@ -111,7 +115,8 @@ class AStar(grammar: ParsedGrammar, heuristicMode: String = "avg") extends LogTr
checkpointFile: String,
outputCsvFile: String
): Option[mutable.Set[ModelNodeTMP]] = {

require(initialConstraint.signature.output == BoolType,
s"Initial constraint must return BoolType, but got ${initialConstraint.signature.output}")
val gScore = mutable.Map[Expression[?], Int]()

if (!isInitialized) {
Expand Down Expand Up @@ -406,7 +411,7 @@ class AStar(grammar: ParsedGrammar, heuristicMode: String = "avg") extends LogTr
}

debug(s"Generated: $candidateTree to simplified $simplifiedTree")
val result = Scanner.visitAll(simplifiedTree, EnsureAnyVarExists(), DenyDivByZero())
val result = Scanner.visitAll(simplifiedTree, EnsureAnyVarExists(), DenyDivByZero(), MaxDepth(5))

if (result.isAllowed) {
Profiler.recordValue("accepted", 1)
Expand Down
9 changes: 9 additions & 0 deletions src/main/scala/com/beepboop/app/components/Arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,20 @@ object Orable {

trait Xorable[T] extends Serializable{
def xor(a: T, b: T): T
def distance(leftDist: Integer, rightDist: Integer, a: T, b: T): Integer
}

object Xorable {
implicit object BoolIsXorable extends Xorable[Boolean] {
override def xor(a: Boolean, b: Boolean): Boolean = a ^ b

override def distance(leftDist: Integer, rightDist: Integer, leftEval: Boolean, rightEval: Boolean): Integer = {
if (leftEval ^ rightEval) {
Math.min(leftDist, rightDist)
} else {
Math.max(1, Math.min(leftDist, rightDist))
}
}
}
}

Expand Down
46 changes: 33 additions & 13 deletions src/main/scala/com/beepboop/app/components/ComponentRegistry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import com.beepboop.app.components.Operator
import com.beepboop.app.components.Expression
import com.beepboop.app.components.SetIntContainsInt
import com.beepboop.app.components.StrEqExpression.StrEqFactory
import com.beepboop.app.dataprovider.{ConfigLoader, DataProvider}
import com.beepboop.app.dataprovider.DataProvider
import com.beepboop.app.logger.LogTrait
import com.beepboop.app.utils.AppConfig
import org.yaml.snakeyaml.Yaml

import scala.jdk.CollectionConverters.*
Expand All @@ -20,6 +21,7 @@ sealed trait ExpressionType
case object IntType extends ExpressionType
case object BoolType extends ExpressionType
case object ListIntType extends ExpressionType
case object ListBoolType extends ExpressionType
case object SetIntType extends ExpressionType
case object IteratorType extends ExpressionType
case object UnknownType extends ExpressionType
Expand All @@ -44,6 +46,8 @@ def scalaTypeToExprType(cls: Class[?]): ExpressionType = cls match {
BoolType
case c if c == classOf[List[Integer]] =>
ListIntType
case c if c == classOf[List[Boolean]] =>
ListBoolType

case c if c == classOf[Set[Int]] || classOf[Set[?]].isAssignableFrom(c) =>
SetIntType
Expand Down Expand Up @@ -80,8 +84,8 @@ object ComponentRegistry extends LogTrait {
new AddOperator[Integer],
new SubOperator[Integer],
new MulOperator[Integer],
new DivOperator[Integer],
new ModOperator[Integer],
//new DivOperator[Integer],
//new ModOperator[Integer],


// relational
Expand All @@ -105,8 +109,8 @@ object ComponentRegistry extends LogTrait {
new XorOperator[Boolean],
new ImpliesOperator[Boolean],

new ContainsOperator[List[Integer], Integer],
new ContainsOperator[Set[Int], Int]
new ContainsOperator[List[Integer], Integer],
new ContainsOperator[Set[Int], Int]
)

private val unaryOperators: List[UnaryOperator[?]] = List(
Expand All @@ -123,11 +127,11 @@ object ComponentRegistry extends LogTrait {

private val allConstantFactories: List[Creatable] = List(
Constant.asCreatable[Integer](() => scala.util.Random.nextInt(10)),
Constant.asCreatable[Boolean](() => scala.util.Random.nextBoolean())
//Constant.asCreatable[Boolean](() => scala.util.Random.nextBoolean())
)
private val allArrayElementFactories: List[Creatable] = List(
ArrayElement.asCreatable[Integer](),
ArrayElement.asCreatable[Boolean](),
ArrayElement.asCreatable[List[Integer]]()
)

private val expressionFactories: List[Creatable] = List(
Expand All @@ -142,23 +146,39 @@ object ComponentRegistry extends LogTrait {
DiffnExpression.DiffnFactory,
ValuePrecedesChainExpression.ValuePrecedesChainFactory,
StrEqExpression.StrEqFactory,
AllDifferentExceptZeroExpression.Factory,
ArgSortExpression.Factory,
SymmetryBreakingExpression.Factory,
SetComprehensionExpression.IntSetComprehensionFactory
//CumulativeExpression
//LexicographicalExpression.asCreatable()
)


val creatables: List[Creatable] = (
binaryOperators.map(op => BinaryExpression.asCreatable(op)) ++
val staticCreatables: List[Creatable] = (
binaryOperators.map(op => BinaryExpression.asCreatable(op)) ++
unaryOperators.map(op => UnaryExpression.asCreatable(op)) ++
allConstantFactories ++
allVariablesFactories ++
expressionFactories ++
allArrayElementFactories
).filter(c =>
val className = c.toString
ConfigLoader.getWeight(className) > 0.0
AppConfig.getWeight(className) > 0.0
)
debug(s"creatables: $creatables")

def creatables: List[Creatable] = {
val currentVariables = DataProvider.getVariableCreatables
.filter(v => v.templateSignature.output != UnknownType)
val all = staticCreatables ++ currentVariables
all
}

creatables.foreach { c =>
val sig = c.templateSignature
val inputs = sig.inputs.map(_.toString).mkString(", ")
debug(s"COMPONENT: ${c.toString} | OUTPUT: ${sig.output} | INPUTS: ($inputs)")
}




Expand All @@ -180,4 +200,4 @@ object ComponentRegistry extends LogTrait {
allOperators.filter(_.signature == sig)
}

}
}
Loading