diff --git a/src/main/scala/com/beepboop/app/MainApp.scala b/src/main/scala/com/beepboop/app/MainApp.scala index 26a7cd3..034dc53 100644 --- a/src/main/scala/com/beepboop/app/MainApp.scala +++ b/src/main/scala/com/beepboop/app/MainApp.scala @@ -28,12 +28,11 @@ import com.beepboop.app.logger.LogTrait object MainApp extends LogTrait { def main(args: Array[String]): Unit = { - val configOpt = ArgumentParser.parse(args) if (configOpt.isEmpty) return val config = configOpt.get - com.beepboop.app.dataprovider.ConfigLoader.initialize(config.configPath) + AppConfig.init(config) info("--- Step 1: Configuration ---") info(s"Model: ${config.modelPath}") @@ -127,9 +126,6 @@ object MainApp extends LogTrait { info("--- Step 4. Starting constraint picking ---") - ConstraintSaver.setConfig(config) - - ConstraintPicker.setConfig(config) /* vvv comment if no gurobi license present vvv */ //ConstraintPicker.runInitial(result.get) diff --git a/src/main/scala/com/beepboop/app/MinizincModelListener.scala b/src/main/scala/com/beepboop/app/MinizincModelListener.scala index 9ec58cc..40887ea 100644 --- a/src/main/scala/com/beepboop/app/MinizincModelListener.scala +++ b/src/main/scala/com/beepboop/app/MinizincModelListener.scala @@ -8,6 +8,7 @@ import com.beepboop.app.dataprovider.* import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters.* +import scala.util.Try class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean = true) extends NewMinizincParserBaseListener { @@ -22,6 +23,23 @@ class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean tokens.getText(new Interval(start, stop)) } + private def parseSimpleValue(text: String): Any = { + val t = text.trim + if (t == "true") return true + if (t == "false") return false + + Try(t.toInt).toOption match { + case Some(i) => return i + case None => + } + + Try(t.toDouble).toOption match { + case Some(d) => return d + case None => + } + + None + } override def enterVar_decl_item(ctx: Var_decl_itemContext): Unit = { assert(ctx != null) @@ -33,64 +51,89 @@ class MinizincModelListener(tokens: CommonTokenStream, extendedDataType: Boolean var detailedFullType = DataType(null, false, false) var expr = "" var isVar = false - + var initialValue: Any = None if (ctx.ti_expr_and_id().ti_expr().array_ti_expr() != null) { - if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().var_par().getText != "var") - { - // not var - fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr()) - if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident()), true, true) - } else { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), true, false) - } - - // try to get expression for given decl Item, useful for set of int? - if (ctx.expr() != null) { - expr = ctx.expr().getText - } - } else { - fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr()) - if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) { - if (ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null) { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().ident()), true, true) - } else { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr()), true, true) - } - } else { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().array_ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), true, false) - } - isVar = true - } - - } else if (ctx.ti_expr_and_id().ti_expr().base_ti_expr() != null) { - if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().var_par().getText != "var") { - fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr()) - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false, Option(ctx.ti_expr_and_id().ti_expr().base_ti_expr().set_ti() != null).getOrElse(false)) + val arrayTi = ctx.ti_expr_and_id().ti_expr().array_ti_expr() + val baseTi = arrayTi.base_ti_expr() + + val isSetType = baseTi.set_ti() != null && baseTi.set_ti().getText.toLowerCase.contains("set") + + if (baseTi.var_par().getText != "var") { + fullType = textWithSpaces(arrayTi) + + val typeName = if (baseTi.base_ti_expr_tail().base_type() == null) { + textWithSpaces(baseTi.base_ti_expr_tail().ident()) } else { - fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr()) - if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null){ - if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null){ - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident()), false, true) - } else { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr()), false, true) - } - } else { - detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false) - } - isVar = true + textWithSpaces(baseTi.base_ti_expr_tail().base_type()) } + val isIdent = baseTi.base_ti_expr_tail().base_type() == null + + detailedFullType = DataType(typeName, true, isIdent, isSet = isSetType) + if (ctx.expr() != null) { expr = ctx.expr().getText } + } else { + fullType = textWithSpaces(arrayTi) + val tail = baseTi.base_ti_expr_tail() + + if (tail.base_type() == null) { + if (tail.ident() != null) { + detailedFullType = DataType(textWithSpaces(tail.ident()), true, true, isSet = isSetType) + } else { + detailedFullType = DataType(textWithSpaces(arrayTi), true, true, isSet = isSetType) + } + } else { + detailedFullType = DataType(textWithSpaces(tail.base_type()), true, false, isSet = isSetType) + } + isVar = true + } + } else if (ctx.ti_expr_and_id().ti_expr().base_ti_expr() != null) { + val baseTi = ctx.ti_expr_and_id().ti_expr().base_ti_expr() + val isSetType = baseTi.set_ti() != null && baseTi.set_ti().getText.toLowerCase.contains("set") + + if (baseTi.var_par().getText != "var") { + fullType = textWithSpaces(baseTi) + val typeName = if (baseTi.base_ti_expr_tail().base_type() != null) { + textWithSpaces(baseTi.base_ti_expr_tail().base_type()) + } else { + baseTi.base_ti_expr_tail().getText + } + detailedFullType = DataType( + dataType = typeName, + isArray = false, + isIdentifier = false, + isSet = isSetType + ) + } else { + fullType = textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr()) + if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type() == null) { + if (ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident() != null) { + detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().ident()), false, true) + } else { + detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr()), false, true) + } + } else { + detailedFullType = DataType(textWithSpaces(ctx.ti_expr_and_id().ti_expr().base_ti_expr().base_ti_expr_tail().base_type()), false, false) + } + isVar = true + } + + if (ctx.expr() != null) { + expr = ctx.expr().getText + } + } + + if (ctx.expr() != null) { + initialValue = parseSimpleValue(ctx.expr().getText) } - val dataItem = if(extendedDataType){ - DataItem(name = name, dataType = fullType, isVar = isVar, detailedDataType = detailedFullType, expr = expr) + val dataItem = if (extendedDataType) { + DataItem(name = name, dataType = fullType, isVar = isVar, value = initialValue, detailedDataType = detailedFullType, expr = expr) } else { - DataItem(name = name, dataType = fullType, isVar = isVar, detailedDataType = null, expr = expr) + DataItem(name = name, dataType = fullType, isVar = isVar, value = initialValue, detailedDataType = null, expr = expr) } dataItemsBuffer += dataItem } diff --git a/src/main/scala/com/beepboop/app/astar/AStar.scala b/src/main/scala/com/beepboop/app/astar/AStar.scala index d3b9385..4475d76 100644 --- a/src/main/scala/com/beepboop/app/astar/AStar.scala +++ b/src/main/scala/com/beepboop/app/astar/AStar.scala @@ -1,7 +1,7 @@ package com.beepboop.app.astar import com.beepboop.app.* -import com.beepboop.app.components.{BinaryExpression, Expression} +import com.beepboop.app.components.{BinaryExpression, BoolType, Expression} import com.beepboop.app.dataprovider.{DataItem, DataProvider} import com.beepboop.app.logger.LogTrait import com.beepboop.app.logger.Profiler @@ -12,7 +12,7 @@ import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} import scala.collection.parallel.CollectionConverters.RangeIsParallelizable import com.beepboop.app.dataprovider.{AStarSnapshot, PersistenceManager} -import com.beepboop.app.policy.{Compliant, DenyDivByZero, EnsureAnyVarExists, NonCompliant, Scanner} +import com.beepboop.app.policy.{Compliant, DenyDivByZero, EnsureAnyVarExists, MaxDepth, NonCompliant, Scanner} import com.beepboop.app.postprocessor.Postprocessor import scala.collection.mutable @@ -38,6 +38,10 @@ case class ModelNodeTMP( ) extends Serializable { val f: Int = g + h + override def toString: String = { + s"[f=$f, g=$g, h=$h] ${constraint.toString}" + } + override def equals(obj: Any): Boolean = obj match { case that: ModelNodeTMP => this.constraint == that.constraint case _ => false @@ -111,7 +115,8 @@ class AStar(grammar: ParsedGrammar, heuristicMode: String = "avg") extends LogTr checkpointFile: String, outputCsvFile: String ): Option[mutable.Set[ModelNodeTMP]] = { - + require(initialConstraint.signature.output == BoolType, + s"Initial constraint must return BoolType, but got ${initialConstraint.signature.output}") val gScore = mutable.Map[Expression[?], Int]() if (!isInitialized) { @@ -406,7 +411,7 @@ class AStar(grammar: ParsedGrammar, heuristicMode: String = "avg") extends LogTr } debug(s"Generated: $candidateTree to simplified $simplifiedTree") - val result = Scanner.visitAll(simplifiedTree, EnsureAnyVarExists(), DenyDivByZero()) + val result = Scanner.visitAll(simplifiedTree, EnsureAnyVarExists(), DenyDivByZero(), MaxDepth(5)) if (result.isAllowed) { Profiler.recordValue("accepted", 1) diff --git a/src/main/scala/com/beepboop/app/components/Arithmetic.scala b/src/main/scala/com/beepboop/app/components/Arithmetic.scala index 24cb96e..ca2ccc6 100644 --- a/src/main/scala/com/beepboop/app/components/Arithmetic.scala +++ b/src/main/scala/com/beepboop/app/components/Arithmetic.scala @@ -84,11 +84,20 @@ object Orable { trait Xorable[T] extends Serializable{ def xor(a: T, b: T): T + def distance(leftDist: Integer, rightDist: Integer, a: T, b: T): Integer } object Xorable { implicit object BoolIsXorable extends Xorable[Boolean] { override def xor(a: Boolean, b: Boolean): Boolean = a ^ b + + override def distance(leftDist: Integer, rightDist: Integer, leftEval: Boolean, rightEval: Boolean): Integer = { + if (leftEval ^ rightEval) { + Math.min(leftDist, rightDist) + } else { + Math.max(1, Math.min(leftDist, rightDist)) + } + } } } diff --git a/src/main/scala/com/beepboop/app/components/ComponentRegistry.scala b/src/main/scala/com/beepboop/app/components/ComponentRegistry.scala index dfb4a07..137403f 100644 --- a/src/main/scala/com/beepboop/app/components/ComponentRegistry.scala +++ b/src/main/scala/com/beepboop/app/components/ComponentRegistry.scala @@ -8,8 +8,9 @@ import com.beepboop.app.components.Operator import com.beepboop.app.components.Expression import com.beepboop.app.components.SetIntContainsInt import com.beepboop.app.components.StrEqExpression.StrEqFactory -import com.beepboop.app.dataprovider.{ConfigLoader, DataProvider} +import com.beepboop.app.dataprovider.DataProvider import com.beepboop.app.logger.LogTrait +import com.beepboop.app.utils.AppConfig import org.yaml.snakeyaml.Yaml import scala.jdk.CollectionConverters.* @@ -20,6 +21,7 @@ sealed trait ExpressionType case object IntType extends ExpressionType case object BoolType extends ExpressionType case object ListIntType extends ExpressionType +case object ListBoolType extends ExpressionType case object SetIntType extends ExpressionType case object IteratorType extends ExpressionType case object UnknownType extends ExpressionType @@ -44,6 +46,8 @@ def scalaTypeToExprType(cls: Class[?]): ExpressionType = cls match { BoolType case c if c == classOf[List[Integer]] => ListIntType + case c if c == classOf[List[Boolean]] => + ListBoolType case c if c == classOf[Set[Int]] || classOf[Set[?]].isAssignableFrom(c) => SetIntType @@ -80,8 +84,8 @@ object ComponentRegistry extends LogTrait { new AddOperator[Integer], new SubOperator[Integer], new MulOperator[Integer], - new DivOperator[Integer], - new ModOperator[Integer], + //new DivOperator[Integer], + //new ModOperator[Integer], // relational @@ -105,8 +109,8 @@ object ComponentRegistry extends LogTrait { new XorOperator[Boolean], new ImpliesOperator[Boolean], - new ContainsOperator[List[Integer], Integer], - new ContainsOperator[Set[Int], Int] + new ContainsOperator[List[Integer], Integer], + new ContainsOperator[Set[Int], Int] ) private val unaryOperators: List[UnaryOperator[?]] = List( @@ -123,11 +127,11 @@ object ComponentRegistry extends LogTrait { private val allConstantFactories: List[Creatable] = List( Constant.asCreatable[Integer](() => scala.util.Random.nextInt(10)), - Constant.asCreatable[Boolean](() => scala.util.Random.nextBoolean()) + //Constant.asCreatable[Boolean](() => scala.util.Random.nextBoolean()) ) private val allArrayElementFactories: List[Creatable] = List( ArrayElement.asCreatable[Integer](), - ArrayElement.asCreatable[Boolean](), + ArrayElement.asCreatable[List[Integer]]() ) private val expressionFactories: List[Creatable] = List( @@ -142,23 +146,39 @@ object ComponentRegistry extends LogTrait { DiffnExpression.DiffnFactory, ValuePrecedesChainExpression.ValuePrecedesChainFactory, StrEqExpression.StrEqFactory, + AllDifferentExceptZeroExpression.Factory, + ArgSortExpression.Factory, + SymmetryBreakingExpression.Factory, + SetComprehensionExpression.IntSetComprehensionFactory //CumulativeExpression //LexicographicalExpression.asCreatable() ) - val creatables: List[Creatable] = ( - binaryOperators.map(op => BinaryExpression.asCreatable(op)) ++ + val staticCreatables: List[Creatable] = ( + binaryOperators.map(op => BinaryExpression.asCreatable(op)) ++ unaryOperators.map(op => UnaryExpression.asCreatable(op)) ++ allConstantFactories ++ - allVariablesFactories ++ expressionFactories ++ allArrayElementFactories ).filter(c => val className = c.toString - ConfigLoader.getWeight(className) > 0.0 + AppConfig.getWeight(className) > 0.0 ) - debug(s"creatables: $creatables") + + def creatables: List[Creatable] = { + val currentVariables = DataProvider.getVariableCreatables + .filter(v => v.templateSignature.output != UnknownType) + val all = staticCreatables ++ currentVariables + all + } + + creatables.foreach { c => + val sig = c.templateSignature + val inputs = sig.inputs.map(_.toString).mkString(", ") + debug(s"COMPONENT: ${c.toString} | OUTPUT: ${sig.output} | INPUTS: ($inputs)") + } + @@ -180,4 +200,4 @@ object ComponentRegistry extends LogTrait { allOperators.filter(_.signature == sig) } -} \ No newline at end of file +} diff --git a/src/main/scala/com/beepboop/app/components/Expression.scala b/src/main/scala/com/beepboop/app/components/Expression.scala index f9ec2c5..2ea9007 100644 --- a/src/main/scala/com/beepboop/app/components/Expression.scala +++ b/src/main/scala/com/beepboop/app/components/Expression.scala @@ -2,7 +2,7 @@ package com.beepboop.app.components import com.beepboop.app.components.* import com.beepboop.app.dataprovider.{DataProvider, VarNameGenerator} -import com.beepboop.app.logger.LogTrait +import com.beepboop.app.logger.{LogTrait, Profiler} import com.beepboop.app.mutations.{ContextAwareCreatable, GenerationContext} import com.beepboop.app.policy.{EnsureSpecificVarExists, NoDuplicateVar, Policy} import com.beepboop.app.postprocessor.Postprocessor @@ -10,6 +10,7 @@ import com.beepboop.app.utils.Implicits.integerNumeric import java.lang.Integer.sum import java.util +import scala.reflect.ensureAccessible /* third party modules */ @@ -44,6 +45,7 @@ trait AutoNamed { } abstract class Expression[ReturnT](implicit val ct: ClassTag[ReturnT]) extends LogTrait with Serializable { + var creatorInfo: String = "Unknown" def toString: String def eval(context: Map[String, Any]): ReturnT def evalToString: String @@ -51,17 +53,37 @@ abstract class Expression[ReturnT](implicit val ct: ClassTag[ReturnT]) extends L def distance(context: Map[String, Any]): Int = { 0 } - def exprDepth: Int = this match { - case c: ComposableExpression => 1 + c.children.map(_.exprDepth).sum + + def complexity: Int = this match { + case c: ComposableExpression => + 1 + c.children.map (_.complexity).sum case _ => 1 } + + def depth: Int = this match { + case c: ComposableExpression if c.children.nonEmpty => + 1 + c.children.map(_.depth).max + case _ => 1 + } + + def structuralSignature: String = this match { + case v: Variable[_] => "VAR" + case c: Constant[_] => "CONST" + case comp: ComposableExpression => + val opName = comp match { + case oc: OperatorContainer => oc.operator.toString + case _ => comp.getClass.getSimpleName + } + s"$opName(${comp.children.map(_.structuralSignature).mkString(",")})" + case _ => "UNKNOWN" + } def symbolCount: Int = this match { case c: ComposableExpression => c.children.map(_.symbolCount).sum - case _ => 1 + case v: Variable[?] => 1 } } -case class Variable[ReturnT : ClassTag ](name: String) extends Expression[ReturnT] { +case class Variable[ReturnT : ClassTag ](name: String, domain: Option[Expression[?]] = None) extends Expression[ReturnT] { override def toString: String = name override def evalToString: String = eval.toString override def signature: Signature = { @@ -74,6 +96,16 @@ case class Variable[ReturnT : ClassTag ](name: String) extends Expression[Return case None => throw new NoSuchElementException(s"Variable '$name' not found in evaluation context.") } } + + def getOffset(context: Map[String, Any]): Int = domain match { + case Some(expr) => + expr.eval(context) match { + case (start: Int, _) => start + case list: List[Int] @unchecked if list.nonEmpty => list.min + case _ => 0 + } + case None => 0 + } } @@ -120,9 +152,31 @@ case class ArrayElement[ReturnT : ClassTag]( } override def eval(context: Map[String, Any]): ReturnT = { - variable.eval(context).apply(index.eval(context)) - } + try { + val rawIndex = index.eval(context) + val offset = index match { + case v: Variable[Integer] @unchecked => v.getOffset(context) + case _ => 0 + } + val adjustedIndex = rawIndex - offset + variable.eval(context).apply(adjustedIndex) + } catch { + case e: IndexOutOfBoundsException => + val hasDomain = index match { + case v: Variable[_] => v.domain.isDefined + case _ => false + } + if (!hasDomain) { + Profiler.recordValue("Index out of bounds (ArrayElement without Domain)", 1) + } else { + Profiler.recordValue("Index out of bounds (ArrayElement with Domain)", 1) + } + throw e + case e: Exception => + throw e + } + } override def toString: String = s"${variable.toString}[${index.toString}]" override def evalToString: String = s"${variable.toString}[${index.evalToString}]" @@ -139,18 +193,32 @@ object ArrayElement { def asCreatable[T: ClassTag](): Creatable = new Creatable with AutoNamed { override def templateSignature: Signature = { val listInputType = scalaTypeToExprType(classTag[List[T]].runtimeClass) - val intInputType = scalaTypeToExprType(classTag[Integer].runtimeClass) + val intInputType = scalaTypeToExprType(classOf[Integer]) val singleOutputType = scalaTypeToExprType(classTag[T].runtimeClass) Signature(inputs = List(listInputType, intInputType), output = singleOutputType) } override def create(children: List[Expression[?]]): Expression[T] = { require(children.length == 2, "ArrayElement requires two children.") - ArrayElement[T](children(0).asInstanceOf[Expression[List[T]]], children(1).asInstanceOf[Expression[Integer]]) + val arrayExpr = children(0) + val expectedListType = templateSignature.inputs.head + + if (arrayExpr.signature.output != expectedListType) { + + throw new ClassCastException( + s"Cannot create ArrayElement returning ${templateSignature.output} " + + s"using a collection of type ${arrayExpr.signature.output}" + ) + } + + ArrayElement[T]( + arrayExpr.asInstanceOf[Expression[List[T]]], + children(1).asInstanceOf[Expression[Integer]] + ) } - override def ownerClass: Class[_] = ArrayElement.getClass + override def ownerClass: Class[_] = ArrayElement.getClass // [cite: 180] } } @@ -223,13 +291,44 @@ case class BinaryExpression[ReturnT : ClassTag]( override def signature: Signature = operator.signature override def distance(context: Map[String, Any]): Int = { + val leftVal = left.eval(context) + val rightVal = right.eval(context) + val isSatisfied = operator.eval(leftVal, rightVal).asInstanceOf[Boolean] + + operator match { - case _: AndOperator[_] => + case _: XorOperator[_] | _: EqualOperator[_] | _: NotEqualOperator[_] + if leftVal.isInstanceOf[Boolean] && rightVal.isInstanceOf[Boolean] => + + val lDist = left.distance(context) + val rDist = right.distance(context) + + if (isSatisfied) { + operator match { + case _: XorOperator[_] => + Math.min(lDist, rDist) + case _ => + Math.max(lDist, rDist) + } + } else { + 1 + } + case _: AndOperator[_] => left.distance(context) + right.distance(context) case _: OrOperator[_] => Math.min(left.distance(context), right.distance(context)) case _: ImpliesOperator[_] => - if (left.eval(context).asInstanceOf[Boolean]) right.distance(context) else 0 + val pVal = left.eval(context).asInstanceOf[Boolean] + val qVal = right.eval(context).asInstanceOf[Boolean] + + val distToFlipP = left.distance(context) + val distToFlipQ = right.distance(context) + + if (pVal && !qVal) { + Math.min(distToFlipP, distToFlipQ) + } else { + Math.max(distToFlipP, distToFlipQ) + } case _ => operator.distance(left.eval(context), right.eval(context)) } @@ -238,6 +337,7 @@ case class BinaryExpression[ReturnT : ClassTag]( object BinaryExpression { def asCreatable[T](op: BinaryOperator[T]): Creatable = new Creatable with AutoNamed { + override def toString: String = op.getClass.getSimpleName.stripSuffix("$") override def templateSignature: Signature = op.signature override def create(children: List[Expression[?]]): Expression[T] = { require(children.length == 2) @@ -289,6 +389,7 @@ case class UnaryExpression[ReturnT : ClassTag]( object UnaryExpression { def asCreatable[T](op: UnaryOperator[T]): Creatable = new Creatable with AutoNamed { + override def toString: String = op.getClass.getSimpleName.stripSuffix("$") override def templateSignature: Signature = op.signature override def create(children: List[Expression[?]]): Expression[T] = { implicit val tag: ClassTag[T] = op.ct @@ -432,7 +533,6 @@ object ForAllExpression { override def templateSignature: Signature = Signature(List(ListIntType, BoolType), BoolType) - // Standard create is not used by the generator anymore for this factory override def create(children: List[Expression[?]]): Expression[?] = throw new UnsupportedOperationException("This factory requires context-aware generation.") @@ -445,8 +545,7 @@ object ForAllExpression { if (collectionOpt.isEmpty) return None val varName = VarNameGenerator.generateUniqueName("f") - - val newCtx = ctx.withVariable(varName, IntType) + val newCtx = ctx.withVariable(varName, IntType, collectionOpt.get) val bodyOpt = recurse(BoolType, newCtx).map(_.asInstanceOf[Expression[Boolean]]) @@ -1391,4 +1490,220 @@ object StrEqExpression { override def ownerClass: Class[_] = StrEqExpression.getClass } -} \ No newline at end of file +} + + +case class SetComprehensionExpression[T: ClassTag]( + head: Expression[T], + iteratorDef: IteratorDef[Integer], + filter: Expression[Boolean] + ) extends Expression[List[T]] + with ComposableExpression + with ScopeModifier { + + override def children: List[Expression[?]] = List(head, iteratorDef, filter) + + override def withNewChildren(newChildren: List[Expression[?]]): Expression[?] = { + require(newChildren.length == 3, "SetComprehensionExpression requires 3 children") + this.copy( + head = newChildren(0).asInstanceOf[Expression[T]], + iteratorDef = newChildren(1).asInstanceOf[IteratorDef[Integer]], + filter = newChildren(2).asInstanceOf[Expression[Boolean]] + ) + } + + override def getAdditionalPolicies: List[Policy] = { + List(EnsureSpecificVarExists(iteratorDef.variableName)) + } + + override def eval(context: Map[String, Any]): List[T] = { + val (varName, domain) = iteratorDef.eval(context) + + domain.flatMap { item => + val localContext = context + (varName -> item) + if (filter.eval(localContext)) { + Some(head.eval(localContext)) + } else { + None + } + } + } + + override def distance(context: Map[String, Any]): Int = 0 + + override def toString: String = s"{ $head | $iteratorDef where $filter }" + override def evalToString: String = s"{ $head | $iteratorDef where $filter }" + + override def signature: Signature = { + val outputType = scalaTypeToExprType(classTag[List[T]].runtimeClass) + Signature(inputs = Nil, output = outputType) + } +} + +object SetComprehensionExpression { + object IntSetComprehensionFactory extends Creatable with AutoNamed with ContextAwareCreatable { + + override def templateSignature: Signature = + Signature( + inputs = List(ListIntType, BoolType, IntType), + output = ListIntType + ) + + override def create(children: List[Expression[?]]): Expression[?] = + throw new UnsupportedOperationException("SetComprehension requires context-aware generation.") + + override def generateExpression( + ctx: GenerationContext, + recurse: (ExpressionType, GenerationContext) => Option[Expression[?]] + ): Option[Expression[?]] = { + + val collectionOpt = recurse(ListIntType, ctx).map(_.asInstanceOf[Expression[List[Integer]]]) + if (collectionOpt.isEmpty) return None + + val varName = VarNameGenerator.generateUniqueName("idx") + val newCtx = ctx.withVariable(varName, IntType, collectionOpt.get) + + val filterOpt = recurse(BoolType, newCtx).map(_.asInstanceOf[Expression[Boolean]]) + if (filterOpt.isEmpty) return None + + val headOpt = recurse(IntType, newCtx).map(_.asInstanceOf[Expression[Integer]]) + if (headOpt.isEmpty) return None + + val iterator = IteratorDef(varName, collectionOpt.get) + + Some(SetComprehensionExpression(headOpt.get, iterator, filterOpt.get)) + } + + override def ownerClass: Class[_] = SetComprehensionExpression.getClass + } +} + +case class AllDifferentExceptZeroExpression( + expr: Expression[List[Integer]] + ) extends Expression[Boolean] with ComposableExpression { + + override def children: List[Expression[?]] = List(expr) + + override def withNewChildren(newChildren: List[Expression[?]]): Expression[?] = { + require(newChildren.length == 1) + AllDifferentExceptZeroExpression(newChildren.head.asInstanceOf[Expression[List[Integer]]]) + } + + override def eval(context: Map[String, Any]): Boolean = { + val list = expr.eval(context) + val nonZero = list.filter(_ != 0) + nonZero.distinct.size == nonZero.size + } + + override def distance(context: Map[String, Any]): Int = { + val list = expr.eval(context) + val nonZero = list.filter(_ != 0) + + val counts = nonZero.groupBy(identity).view.mapValues(_.size).toMap + counts.values.map(c => if (c > 1) c - 1 else 0).sum + } + + override def toString: String = s"alldifferent_except_0($expr)" + override def evalToString: String = s"alldifferent_except_0(${expr.evalToString})" + + override def signature: Signature = Signature(inputs = List(ListIntType), output = BoolType) +} + + +object AllDifferentExceptZeroExpression { + object Factory extends Creatable with AutoNamed { + + override def templateSignature: Signature = + Signature(inputs = List(ListIntType), output = BoolType) + + override def create(children: List[Expression[?]]): Expression[?] = { + require(children.length == 1, "AllDifferentExceptZero requires one child (list).") + + val listExpr = children.head.asInstanceOf[Expression[List[Integer]]] + AllDifferentExceptZeroExpression(listExpr) + } + + override def ownerClass: Class[_] = AllDifferentExceptZeroExpression.getClass + } +} + +case class ArgSortExpression( + listExpr: Expression[List[Integer]] + ) extends Expression[List[Integer]] with ComposableExpression { + + override def children: List[Expression[?]] = List(listExpr) + + override def withNewChildren(newChildren: List[Expression[?]]): Expression[?] = { + require(newChildren.length == 1) + ArgSortExpression(newChildren.head.asInstanceOf[Expression[List[Integer]]]) + } + + override def eval(context: Map[String, Any]): List[Integer] = { + val list = listExpr.eval(context) + + list.zipWithIndex + .sortBy(_._1.intValue()) + .map { case (_, index) => (index + 1).asInstanceOf[Integer] } + } + + override def toString: String = s"arg_sort($listExpr)" + override def evalToString: String = s"arg_sort(${listExpr.evalToString})" + + override def signature: Signature = Signature(inputs = List(ListIntType), output = ListIntType) +} + +object ArgSortExpression { + object Factory extends Creatable with AutoNamed { + + override def templateSignature: Signature = + Signature(inputs = List(ListIntType), output = ListIntType) + + override def create(children: List[Expression[?]]): Expression[?] = { + require(children.length == 1, "ArgSort requires one child (list).") + + val listExpr = children.head.asInstanceOf[Expression[List[Integer]]] + ArgSortExpression(listExpr) + } + + override def ownerClass: Class[_] = ArgSortExpression.getClass + } +} + + +case class SymmetryBreakingExpression( + constraint: Expression[Boolean] + ) extends Expression[Boolean] with ComposableExpression { + + override def children: List[Expression[?]] = List(constraint) + + override def withNewChildren(newChildren: List[Expression[?]]): Expression[?] = { + require(newChildren.length == 1) + SymmetryBreakingExpression(newChildren.head.asInstanceOf[Expression[Boolean]]) + } + + override def eval(context: Map[String, Any]): Boolean = constraint.eval(context) + + override def distance(context: Map[String, Any]): Int = constraint.distance(context) + + override def toString: String = s"symmetry_breaking_constraint($constraint)" + override def evalToString: String = s"symmetry_breaking_constraint(${constraint.evalToString})" + + override def signature: Signature = constraint.signature +} + +object SymmetryBreakingExpression { + object Factory extends Creatable with AutoNamed { + + override def templateSignature: Signature = + Signature(inputs = List(BoolType), output = BoolType) + + override def create(children: List[Expression[?]]): Expression[?] = { + require(children.length == 1, "SymmetryBreaking requires one child (boolean).") + + val boolExpr = children.head.asInstanceOf[Expression[Boolean]] + SymmetryBreakingExpression(boolExpr) + } + + override def ownerClass: Class[_] = SymmetryBreakingExpression.getClass + } +} diff --git a/src/main/scala/com/beepboop/app/components/Operator.scala b/src/main/scala/com/beepboop/app/components/Operator.scala index e1a4594..0342f3b 100644 --- a/src/main/scala/com/beepboop/app/components/Operator.scala +++ b/src/main/scala/com/beepboop/app/components/Operator.scala @@ -73,9 +73,9 @@ case class MulOperator[T: ClassTag]()(implicit strategy: Multiplicable[T]) exten } } -case class DivOperator[T: ClassTag]()(implicit strategy: Divisible[T]) extends BinaryOperator[T] { +case class DivOperator[T: ClassTag]()(implicit strategy: Divisible[T]) extends BinaryOperator[T], LogTrait { override def eval(left: Any, right: Any): T = { - strategy.div(left.asInstanceOf[T], right.asInstanceOf[T]) + strategy.div(left.asInstanceOf[T], right.asInstanceOf[T]) } override def toString: String = "/" @@ -118,22 +118,7 @@ case class EqualOperator[T: ClassTag]()(implicit strategy: Equatable[T]) extends } override def distance(left: Any, right: Any): Int = { - val leftT = left.asInstanceOf[T] - val rightT = right.asInstanceOf[T] - - if (classOf[Number].isAssignableFrom(classTag[T].runtimeClass)) { - Math.abs(leftT.asInstanceOf[Number].intValue() - rightT.asInstanceOf[Number].intValue()) - } else if (classOf[Set[?]].isAssignableFrom(classTag[T].runtimeClass)) { - try { - val s1 = leftT.asInstanceOf[Set[Any]] - val s2 = rightT.asInstanceOf[Set[Any]] - (s1 diff s2).size + (s2 diff s1).size - } catch { - case _: Exception => if (strategy.equal(leftT, rightT)) 0 else 1 - } - } else { - if (strategy.equal(leftT, rightT)) 0 else 1 - } + strategy.distance(left.asInstanceOf[T], right.asInstanceOf[T]) } } @@ -181,12 +166,7 @@ case class LessOperator[T: ClassTag]()(implicit strategy: LessThan[T]) extends B override def distance(left: Any, right: Any): Int = { val leftT = left.asInstanceOf[T] val rightT = right.asInstanceOf[T] - - if (leftT.isInstanceOf[Integer] && rightT.isInstanceOf[Integer]) { - Math.abs(leftT.asInstanceOf[Integer] - (rightT.asInstanceOf[Integer] - 1)) - } else { - if (strategy.less(leftT, rightT)) 0 else 1 - } + strategy.distance(leftT, rightT) } } @@ -201,11 +181,7 @@ case class LessEqualOperator[T: ClassTag]()(implicit strategy: LessEqual[T]) ext override def distance(left: Any, right: Any): Int = { val leftT = left.asInstanceOf[T] val rightT = right.asInstanceOf[T] - if (leftT.isInstanceOf[Integer] && rightT.isInstanceOf[Integer]) { - Math.abs(leftT.asInstanceOf[Integer] - rightT.asInstanceOf[Integer]) - } else { - if (strategy.lessEqual(leftT, rightT)) 0 else 1 - } + strategy.distance(leftT, rightT) } } @@ -226,11 +202,7 @@ case class GreaterOperator[T: ClassTag]()(implicit strategy: GreaterThan[T]) ext override def distance(left: Any, right: Any): Int = { val leftT = left.asInstanceOf[T] val rightT = right.asInstanceOf[T] - if (leftT.isInstanceOf[Integer] && rightT.isInstanceOf[Integer]) { - Math.abs(leftT.asInstanceOf[Integer] - (rightT.asInstanceOf[Integer] + 1)) - } else { - if (strategy.greater(leftT, rightT)) 0 else 1 - } + strategy.distance(leftT, rightT) } } @@ -251,11 +223,7 @@ case class GreaterEqualOperator[T: ClassTag]()(implicit strategy: GreaterEqual[T override def distance(left: Any, right: Any): Int = { val leftT = left.asInstanceOf[T] val rightT = right.asInstanceOf[T] - if (leftT.isInstanceOf[Integer] && rightT.isInstanceOf[Integer]) { - Math.abs(leftT.asInstanceOf[Integer] - rightT.asInstanceOf[Integer]) - } else { - if (strategy.greaterEqual(leftT, rightT)) 0 else 1 - } + strategy.distance(leftT,rightT) } } diff --git a/src/main/scala/com/beepboop/app/components/Relational.scala b/src/main/scala/com/beepboop/app/components/Relational.scala index 09512b6..1aba74f 100644 --- a/src/main/scala/com/beepboop/app/components/Relational.scala +++ b/src/main/scala/com/beepboop/app/components/Relational.scala @@ -1,27 +1,46 @@ package com.beepboop.app.components -trait Equatable[T] extends Serializable{ + +trait Equatable[T] extends Serializable { def equal(a: T, b: T): Boolean + def distance(a: T, b: T): Int } object Equatable { implicit object IntIsEquatable extends Equatable[Integer] { override def equal(a: Integer, b: Integer): Boolean = a == b + + override def distance(a: Integer, b: Integer): Int = Math.abs(a - b) } implicit object BoolIsEquatable extends Equatable[Boolean] { override def equal(a: Boolean, b: Boolean): Boolean = a == b + + override def distance(a: Boolean, b: Boolean): Int = if (a == b) 0 else 1 } implicit object ListIntIsEquatable extends Equatable[List[Integer]] { - override def equal(a: List[Integer], b: List[Integer]): Boolean = a == b + override def equal(a: List[Integer], b: List[Integer]): Boolean = { + require(a.size == b.size, "arrays must be equal length") + a == b + } + override def distance(a: List[Integer], b: List[Integer]): Int = { + if (a.size != b.size) { + 1000 + Math.abs(a.size - b.size) + } else { + a.zip(b).map { case (v1, v2) => Math.abs(v1 - v2) }.sum + } + } } implicit object SetIntIsEquatable extends Equatable[Set[Integer]] { override def equal(a: Set[Integer], b: Set[Integer]): Boolean = a == b + + override def distance(a: Set[Integer], b: Set[Integer]): Int = { + (a diff b).size + (b diff a).size + } } } - trait NotEquatable[T] extends Serializable{ def notEqual(a: T, b: T): Boolean def distance(a: T, b: T): Int @@ -41,12 +60,20 @@ object NotEquatable { } implicit object ListIntIsNotEquatable extends NotEquatable[List[Integer]] { - override def notEqual(a: List[Integer], b: List[Integer]): Boolean = a != b + override def notEqual(a: List[Integer], b: List[Integer]): Boolean = { + require(a.size == b.size, "arrays must be equal length") + a != b + } override def distance(a: List[Integer], b: List[Integer]): Int = { - if (a != b) 0 else 1 + require(a.size == b.size, "arrays must be equal length") + if (a != b) { + val diffCount = a.zip(b).count { case (i, j) => i != j } + diffCount - 1 + } else { + 1 + } } } - implicit object SetIntIsNotEquatable extends NotEquatable[Set[Integer]] { override def notEqual(a: Set[Integer], b: Set[Integer]): Boolean = a != b override def distance(a: Set[Integer], b: Set[Integer]): Int = { @@ -65,9 +92,15 @@ implicit object ListIntContainsInt extends Contains[List[Integer], Integer] { override def contains(left: List[Integer], right: Integer): Boolean = left.contains(right) override def distance(left: List[Integer], right: Integer): Int = { - if (left.isEmpty) return Int.MaxValue - if (left.contains(right)) return 0 - left.map(x => (x - right).abs).min + if (left.isEmpty) { + return 1000 + } + val occurrences = left.count(_ == right) + if (occurrences > 0) { + occurrences - 1 + } else { + left.map(x => (x - right).abs).min + } } } @@ -83,41 +116,70 @@ implicit object SetIntContainsInt extends Contains[Set[Int], Int] { trait LessThan[T] extends Serializable{ def less(a: T, b: T): Boolean + def distance(a: T, b: T): Integer } object LessThan { implicit object IntIsLessThan extends LessThan[Integer] { override def less(a: Integer, b: Integer): Boolean = a < b + + override def distance(a: Integer, b: Integer): Integer = { + if (a < b) { + (b - 1) - a + } else { + (a - b) + 1 + } + } + } } -trait GreaterThan[T] extends Serializable{ +trait GreaterThan[T] extends Serializable { def greater(a: T, b: T): Boolean + + def distance(a: T, b: T): Integer } object GreaterThan { implicit object IntIsGreaterThan extends GreaterThan[Integer] { override def greater(a: Integer, b: Integer): Boolean = a > b + + override def distance(a: Integer, b: Integer): Integer = { + if (a > b) { + a - (b + 1) + } else { + (b + 1) - a + } + } } } + trait LessEqual[T] extends Serializable{ def lessEqual(a: T, b: T): Boolean + def distance(a: T, b: T): Integer } object LessEqual { implicit object IntIsLessEqual extends LessEqual[Integer] { override def lessEqual(a: Integer, b: Integer): Boolean = a <= b + override def distance(a: Integer, b: Integer): Integer = { + Math.abs(a - b) + } } } trait GreaterEqual[T] extends Serializable{ def greaterEqual(a: T, b: T): Boolean + def distance(a: T, b: T): Integer } object GreaterEqual { implicit object IntIsGreaterEqual extends GreaterEqual[Integer] { override def greaterEqual(a: Integer, b: Integer): Boolean = a >= b + override def distance(a: Integer, b: Integer): Integer = { + Math.abs(a - b) + } } } diff --git a/src/main/scala/com/beepboop/app/cpicker/ConstraintPicker.scala b/src/main/scala/com/beepboop/app/cpicker/ConstraintPicker.scala index 25f87d6..f63be24 100644 --- a/src/main/scala/com/beepboop/app/cpicker/ConstraintPicker.scala +++ b/src/main/scala/com/beepboop/app/cpicker/ConstraintPicker.scala @@ -24,11 +24,8 @@ object ExpressionOrdering extends Ordering[Expression[?]] { } object ConstraintPicker extends LogTrait { - var config: AppConfig = null + var config: AppConfig = AppConfig.get - def setConfig(config: AppConfig): Unit = { - this.config = config - } private def order(item: ConstraintData): Double = { item.solCount.toDouble * item.distributionScore @@ -49,7 +46,7 @@ object ConstraintPicker extends LogTrait { initialWorkload.foreach { tmpNode => val expr = tmpNode.constraint - val depth = expr.exprDepth + val depth = expr.depth val symbols = expr.symbolCount val distScore = DistributionScorer.scoreNormal(symbols, distMean, distStd) @@ -99,7 +96,7 @@ object ConstraintPicker extends LogTrait { batchWorkload.foreach { group => - val totalDepth = group.map(_.exprDepth).sum + val totalDepth = group.map(_.depth).sum val totalSymbols = group.map(_.symbolCount).sum val avgDistScore = group.map(expr => diff --git a/src/main/scala/com/beepboop/app/cpicker/ConstraintSaver.scala b/src/main/scala/com/beepboop/app/cpicker/ConstraintSaver.scala index 5b918a8..cdb8769 100644 --- a/src/main/scala/com/beepboop/app/cpicker/ConstraintSaver.scala +++ b/src/main/scala/com/beepboop/app/cpicker/ConstraintSaver.scala @@ -9,11 +9,8 @@ import java.io.File object ConstraintSaver { - var config: AppConfig = null + var config: AppConfig = AppConfig.get - def setConfig(config: AppConfig): Unit = { - this.config = config - } def save(constraints: Expression[?]*): Path = { val tempPath: Path = Files.createTempFile("mzn_temp_", ".mzn") tempPath.toFile.deleteOnExit(); diff --git a/src/main/scala/com/beepboop/app/dataprovider/ConfigLoader.scala b/src/main/scala/com/beepboop/app/dataprovider/ConfigLoader.scala index b380b2e..7085da0 100644 --- a/src/main/scala/com/beepboop/app/dataprovider/ConfigLoader.scala +++ b/src/main/scala/com/beepboop/app/dataprovider/ConfigLoader.scala @@ -20,33 +20,3 @@ case class AlgorithmConfig( mutations: List[Mutation], logging: LogConfig ) derives ConfigReader - -object ConfigLoader { - import pureconfig.module.yaml.* - - private var _settings: Option[AlgorithmConfig] = None - - def initialize(path: String): Unit = { - val loaded = YamlConfigSource.file(path).loadOrThrow[AlgorithmConfig] - _settings = Some(loaded) - println(s"Configuration loaded from: $path") - } - - def settings: AlgorithmConfig = { - _settings match { - case Some(s) => s - case None => - println("WARNING: ConfigLoader not initialized explicitly. Loading default 'config.yaml'.") - initialize("config.yaml") - _settings.get - } - } - - def getWeight(componentName: String): Double = { - settings.expressionWeights.getOrElse(componentName, 0.0) - } - - def getClassLogConfig(className: String): ClassLogConfig = { - settings.logging.classes.getOrElse(className.stripSuffix("$"), ClassLogConfig()) - } -} \ No newline at end of file diff --git a/src/main/scala/com/beepboop/app/dataprovider/DataImporter.scala b/src/main/scala/com/beepboop/app/dataprovider/DataImporter.scala index 953ef56..988f0cf 100644 --- a/src/main/scala/com/beepboop/app/dataprovider/DataImporter.scala +++ b/src/main/scala/com/beepboop/app/dataprovider/DataImporter.scala @@ -1,40 +1,36 @@ package com.beepboop.app.dataprovider - import scala.io.Source import scala.util.{Failure, Success, Try} import scala.language.postfixOps import spray.json.* -import DefaultJsonProtocol.* import com.beepboop.app.MinizincDznVisitor import com.beepboop.app.logger.LogTrait -import com.beepboop.parser.NewMinizincParserBaseListener import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} import com.beepboop.parser.{NewMinizincLexer, NewMinizincParser} import scala.collection.JavaConverters.asScalaBufferConverter -import scala.collection.JavaConverters.collectionAsScalaIterableConverter -import scala.collection.JavaConverters.iterableAsScalaIterableConverter object DataImporter extends DefaultJsonProtocol, LogTrait { - def prepareSets(data: String, dataItems: List[DataItem]): Unit = { - dataItems.filter(_.detailedDataType.isSet).foreach { item => + dataItems.filter(i => i.detailedDataType != null && i.detailedDataType.isSet && (i.value == null || i.value == None)).foreach { item => val expr = item.expr - if (expr != "") { + if (expr != null && expr.trim.nonEmpty) { try { - val parts = parseRange(expr) + if (expr.contains("..")) { + val parts = parseRange(expr) - val minVal = resolveBound(parts(0), dataItems) - val maxVal = resolveBound(parts(1), dataItems) + val minVal = resolveBound(parts(0), dataItems) + val maxVal = resolveBound(parts(1), dataItems) - val setValues = (minVal to maxVal).toList - item.value = setValues + val setValues = (minVal to maxVal).toList + item.value = setValues - info(s"Set ${item.name} resolved: $minVal..$maxVal (size: ${setValues.size})") + info(s"Set ${item.name} resolved: $minVal..$maxVal (size: ${setValues.size})") + } } catch { case e: Exception => warn(s"Failed to resolve set '${item.name}' with expr '$expr': ${e.getMessage}") @@ -50,18 +46,40 @@ object DataImporter extends DefaultJsonProtocol, LogTrait { return raw.toInt } - val sumPattern = "sum\\((.*)\\)".r + val SumPattern = """sum\s*\(\s*(.*)\s*\)""".r + val MinPattern = """min\s*\(\s*(.*)\s*\)""".r + val MaxPattern = """max\s*\(\s*(.*)\s*\)""".r + raw match { - case sumPattern(varName) => - val listValue = lookupDependency(varName, dataItems) - listValue match { - case l: List[_] => l.map(_.toString.toInt).sum - case _ => throw new IllegalArgumentException(s"Item '$varName' is not a list, cannot calculate sum.") - } + case SumPattern(varName) => + getAsIntList(varName, dataItems).sum + + case MinPattern(varName) => + val list = getAsIntList(varName, dataItems) + if (list.nonEmpty) list.min else 0 + + case MaxPattern(varName) => + val list = getAsIntList(varName, dataItems) + if (list.nonEmpty) list.max else 0 case varName => - val value = lookupDependency(varName, dataItems) - value.toString.toInt + lookupDependency(varName, dataItems) match { + case i: Int => i + case s: String => Try(s.toInt).getOrElse(throw new IllegalArgumentException(s"Value '$s' for '$varName' is not an integer.")) + case other => throw new IllegalArgumentException(s"Dependency '$varName' is not a scalar integer (found ${other.getClass.getSimpleName}).") + } + } + } + + private def getAsIntList(name: String, dataItems: List[DataItem]): List[Int] = { + lookupDependency(name, dataItems) match { + case l: List[_] => l.map { + case i: Int => i + case s: String => Try(s.toInt).getOrElse(0) + case d: Double => d.toInt + case other => throw new IllegalArgumentException(s"List element in '$name' is not numeric: $other") + } + case other => throw new IllegalArgumentException(s"Dependency '$name' is not a list (found ${other.getClass.getSimpleName}).") } } diff --git a/src/main/scala/com/beepboop/app/dataprovider/DataProvider.scala b/src/main/scala/com/beepboop/app/dataprovider/DataProvider.scala index b09be67..ca16401 100644 --- a/src/main/scala/com/beepboop/app/dataprovider/DataProvider.scala +++ b/src/main/scala/com/beepboop/app/dataprovider/DataProvider.scala @@ -150,7 +150,7 @@ object DataProvider extends LogTrait { DataImporter.importDataFile(dPath, instanceParams) val paramMap: Map[String, Any] = instanceParams - .filter(_.value != null) + .filter(p => p.value != null && p.value != None) .map(p => p.name -> p.value) .toMap @@ -174,6 +174,26 @@ object DataProvider extends LogTrait { info(s"Total: Loaded $solutionCount contexts across ${dataPaths.size} instances.") + info("==================== DATA LOADING SUMMARY ====================") + val previewCtx = if (solutionContexts.nonEmpty) solutionContexts.head else Map.empty + + (variables ++ parameters).sortBy(_.name).foreach { item => + val internalKind = if (item.isVar) "VAR" else "PAR" + val internalType = getExpressionType(item) + + val valuePreview = previewCtx.get(item.name) match { + case Some(l: List[_]) => + if (l.nonEmpty && l.head.isInstanceOf[Set[_]]) s"List[Set] (size=${l.size}, first=${l.head})" + else s"List (size=${l.size})" + case Some(s: Set[_]) => s"Set (size=${s.size})" + case Some(v) => v.toString + case None => "MISSING / NONE" + } + + info(f"[$internalKind] ${item.name}%-20s | Type: ${internalType.toString}%-15s | Val: $valuePreview") + } + info("==============================================================\n") + initializeCreatables(solutionContexts.head, modelParamsSchema) } @@ -194,10 +214,25 @@ object DataProvider extends LogTrait { } def getExpressionType(item: DataItem): ExpressionType = { - val typeStr = Option(item.detailedDataType).map(_.toString.toLowerCase).getOrElse("") - if (typeStr.contains("set") && typeStr.contains("array")) ListSetIntType - else if (typeStr.contains("array")) ListIntType - else IntType + val typeStr = Option(item.detailedDataType) + .map(_.dataType.toLowerCase) + .getOrElse(item.dataType.toLowerCase) + + val isArray = Option(item.detailedDataType).exists(_.isArray) + val isSet = Option(item.detailedDataType).exists(_.isSet) + + if (isArray && isSet) ListSetIntType + else if (isArray) ListIntType + else if (isSet) ListIntType + else if (typeStr.contains("bool")) BoolType + else if (typeStr.contains("int") || typeStr.contains("..")) IntType + else { + solutionContexts.headOption.flatMap(_.get(item.name)) match { + case Some(_: Int) | Some(_: java.lang.Integer) => IntType + case Some(_: Boolean) | Some(_: java.lang.Boolean) => BoolType + case _ => UnknownType + } + } } private def inferTypeFromValue(value: Any): ExpressionType = value match { diff --git a/src/main/scala/com/beepboop/app/logger/LogTrait.scala b/src/main/scala/com/beepboop/app/logger/LogTrait.scala index 58cb94b..2c720a0 100644 --- a/src/main/scala/com/beepboop/app/logger/LogTrait.scala +++ b/src/main/scala/com/beepboop/app/logger/LogTrait.scala @@ -2,11 +2,11 @@ package com.beepboop.app.logger // third party import com.beepboop.app.dataprovider.ClassLogConfig +import com.beepboop.app.utils.AppConfig import org.slf4j.LoggerFactory import com.typesafe.scalalogging.Logger // own import com.beepboop.app.logger.* -import com.beepboop.app.dataprovider.ConfigLoader import com.typesafe.scalalogging.Logger import org.slf4j.LoggerFactory @@ -15,11 +15,12 @@ trait LogTrait { @transient protected lazy val logger: Logger = Logger(LoggerFactory.getLogger(getClass.getName)) - private lazy val globalLogEnable: Boolean = ConfigLoader.settings.logging.enabled - private lazy val logConfigDebug: Boolean = ConfigLoader.settings.logging.logDebug + private lazy val globalLogEnable: Boolean = AppConfig.algorithm.logging.enabled + private lazy val logConfigDebug: Boolean = AppConfig.algorithm.logging.logDebug + private lazy val logConfig: ClassLogConfig = { - val tLogConfig = ConfigLoader.getClassLogConfig(getClass.getName) + val tLogConfig = AppConfig.getClassLogConfig(getClass.getName) if (logConfigDebug) { logger.debug(s"Config for: ${getClass.getName} | level: ${tLogConfig.level} | gEnable: $globalLogEnable | enabled: ${tLogConfig.enabled}") } diff --git a/src/main/scala/com/beepboop/app/mutations/ExpressionGenerator.scala b/src/main/scala/com/beepboop/app/mutations/ExpressionGenerator.scala index ac8ef35..e1bb3df 100644 --- a/src/main/scala/com/beepboop/app/mutations/ExpressionGenerator.scala +++ b/src/main/scala/com/beepboop/app/mutations/ExpressionGenerator.scala @@ -1,25 +1,27 @@ package com.beepboop.app.mutations /* own modules */ import com.beepboop.app.components.* -import com.beepboop.app.dataprovider.ConfigLoader +import com.beepboop.app.utils.AppConfig import com.beepboop.app.logger.LogTrait case class GenerationContext( - variables: Map[ExpressionType, List[String]] = Map.empty + variables: Map[ExpressionType, List[String]] = Map.empty, ) { - def withVariable(name: String, varType: ExpressionType): GenerationContext = { + def withVariable(name: String, varType: ExpressionType, domain: Expression[?]): GenerationContext = { val existingNames = variables.getOrElse(varType, List.empty) - copy(variables = variables + (varType -> (name :: existingNames))) + copy( + variables = variables + (varType -> (name :: existingNames)), + ) } } - object ExpressionGenerator extends LogTrait { def generate( requiredType: ExpressionType, maxDepth: Int, - ctx: GenerationContext = GenerationContext() + ctx: GenerationContext = GenerationContext(), + exclude: Option[Class[_]] = None ): Option[Expression[?]] = { val registry = ComponentRegistry @@ -53,55 +55,84 @@ object ExpressionGenerator extends LogTrait { return None } - val chosenCreatable = selectWeighted(possibleCreatables) + + val filteredPool = exclude match { + case Some(cls) => + val excludeName = cls.getSimpleName.stripSuffix("$") + possibleCreatables.filterNot(_.toString == excludeName) + case None => + possibleCreatables + } + //println(exclude.toString) + //println(filteredPool.mkString(",")) + + + //val finalPool = if (filteredPool.nonEmpty) filteredPool else possibleCreatables + + + val chosenCreatable = selectWeighted(possibleCreatables, ctx) debug(s"Selected: ${chosenCreatable.getClass.getName}") - chosenCreatable match { + val result = chosenCreatable match { case scoped: ContextAwareCreatable => scoped.generateExpression(ctx, (t, c) => generate(t, maxDepth - 1, c)) case standard => val inputs = standard.templateSignature.inputs - val children = inputs.map(inputType => generate(inputType, maxDepth - 1, ctx)) // Pass same ctx - + val children = inputs.map(inputType => generate(inputType, maxDepth - 1, ctx)) if (children.forall(_.isDefined)) { - Some(standard.create(children.flatten)) + val flattenedChildren = children.flatten + val actualChildTypes = flattenedChildren.map(_.signature.output) + if (actualChildTypes == inputs) { + Some(standard.create(flattenedChildren)) + } else { + debug(s"Type mismatch during standard creation: Expected $inputs but got $actualChildTypes") + None + } } else { None } } + result.foreach { expr => + expr.creatorInfo = s"Created by [${chosenCreatable.toString}] at Depth [$maxDepth] for Type [$requiredType] - ${result.get.toString}. " + } + result } - private def selectWeighted(candidates: List[Creatable]): Creatable = { + private def selectWeighted(candidates: List[Creatable], ctx: GenerationContext): Creatable = { val candidatesWithWeights = candidates.map { c => val name = c.toString - val weight = c match { - case _: RandomVariableFactory => 50.0 - case _ => ConfigLoader.settings.expressionWeights.getOrElse(name, 0.0) + case factory: RandomVariableFactory => + val localNames = ctx.variables.getOrElse(factory.varType, Nil) + + if (factory.availableNames.exists(localNames.contains)) { + 800.0 + } else { + 50.0 + } + case _ => + AppConfig.getWeight(name) } (c, weight) } - val totalWeight = candidatesWithWeights.map(_._2).sum - if (totalWeight <= 0) return candidates(scala.util.Random.nextInt(candidates.length)) + if (totalWeight <= 0) { + candidates(scala.util.Random.nextInt(candidates.length)) + } else { + val randomValue = scala.util.Random.nextDouble() * totalWeight - val randomValue = scala.util.Random.nextDouble() * totalWeight + val result = candidatesWithWeights.scanLeft((0.0, None: Option[Creatable])) { + case ((acc, _), (creatable, weight)) => (acc + weight, Some(creatable)) + }.find { case (cumulative, _) => cumulative > randomValue } - var cumulativeWeight = 0.0 - for ((creatable, weight) <- candidatesWithWeights) { - cumulativeWeight += weight - if (randomValue < cumulativeWeight) { - return creatable - } + result.flatMap(_._2).getOrElse(candidates.last) } - - candidates.last } private def sequence[T](opts: List[Option[T]]): Option[List[T]] = @@ -117,4 +148,4 @@ trait ContextAwareCreatable extends Creatable { ctx: GenerationContext, recurse: (ExpressionType, GenerationContext) => Option[Expression[?]] ): Option[Expression[?]] -} \ No newline at end of file +} diff --git a/src/main/scala/com/beepboop/app/mutations/Mutation.scala b/src/main/scala/com/beepboop/app/mutations/Mutation.scala index 7bdc285..3649a20 100644 --- a/src/main/scala/com/beepboop/app/mutations/Mutation.scala +++ b/src/main/scala/com/beepboop/app/mutations/Mutation.scala @@ -1,14 +1,16 @@ package com.beepboop.app.mutations +import com.beepboop.app.components.{ArrayElement, ComposableExpression, IntType} import pureconfig.ConfigReader +import scala.reflect.ClassTag + /* own modules */ import com.beepboop.app.components import com.beepboop.app.components.{BoolType, ComponentRegistry, Constant, DiffnExpression, Expression, ListIntType, OperatorContainer, RandomVariableFactory, Variable} import com.beepboop.app.dataprovider.DataProvider import com.beepboop.app.logger.LogTrait -// Options to keep original names and use 'type' as the discriminator sealed trait Mutation extends LogTrait derives ConfigReader { def enabled: Boolean @@ -137,7 +139,119 @@ case class ReplaceSubtree(val maxDepth: Int, enabled: Boolean = true) extends Mu override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { val requiredType = expression.signature.output - val generated = ExpressionGenerator.generate(requiredType, maxDepth = maxDepth, ctx) + val generated = ExpressionGenerator.generate(requiredType, maxDepth = maxDepth, ctx, Some(expression.getClass)) generated.toList } +} + +case class IteratorCoupling(enabled: Boolean = true) extends Mutation { + override def name: String = "IteratorCoupling" + + override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { + expression match { + case ae: ArrayElement[t] => + implicit val tag: ClassTag[t] = ae.ct + val localInts = ctx.variables.getOrElse(IntType, Nil) + + localInts.map { varName => + ae.copy(index = Variable[Integer](varName)) + } + case _ => Nil + } + } +} + +case class UnwrapExpression(enabled: Boolean = true) extends Mutation { + override def name: String = "UnwrapExpression" + + override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { + expression match { + case c: ComposableExpression => + val compatibleChildren = c.children.filter(_.signature.output == expression.signature.output) + + compatibleChildren + case _ => Nil + } + } +} + + +case class PruneToLeaf(enabled: Boolean = true) extends Mutation { + override def name: String = "PruneToLeaf" + + override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { + expression match { + case c: ComposableExpression if c.children.nonEmpty => + val requiredType = expression.signature.output + + val simpleReplacements = ComponentRegistry.findCreatablesReturning(requiredType) + .filter(_.templateSignature.inputs.isEmpty) + .flatMap { factory => + factory match { + case v: RandomVariableFactory => + v.availableNames.map(n => v.createWithName(n)) + case f => Some(f.create(Nil)) + } + } + + scala.util.Random.shuffle(simpleReplacements).take(3).toList + + case _ => Nil + } + } +} + + +case class InjectQuantifier(enabled: Boolean = true) extends Mutation { + override def name: String = "InjectQuantifier" + + override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { + if (expression.signature.output != components.BoolType) return Nil + + val factories = List( + components.ForAllExpression.ForAllIntListFactory, + //components.ExistsExpression.ExistsIntListFactory + ) + + factories.flatMap { factory => + factory.generateExpression(ctx, (t, c) => + ExpressionGenerator.generate(t, maxDepth = 3, c) + ) + } + } +} + +case class InjectNestedLoop(enabled: Boolean = true) extends Mutation { + override def name: String = "InjectNestedLoop" + + override def apply(expression: Expression[?], ctx: GenerationContext): List[Expression[?]] = { + expression match { + case outerLoop: components.ForAllExpression[_] => + val outerIter = outerLoop.iteratorDef.variableName + val outerType = outerLoop.iteratorDef.collection.signature.output + + val candidateCollections = DataProvider.variables.filter { v => + v.detailedDataType.isArray && + v.detailedDataType.dataType == "int" + } + + if (candidateCollections.isEmpty) return Nil + + val targetCollName = candidateCollections(scala.util.Random.nextInt(candidateCollections.size)).name + + val dependentCollection = components.ArrayElement[List[Integer]]( + components.Variable[List[List[Integer]]](targetCollName), + components.Variable[Integer](outerIter) + ) + + val innerIterName = com.beepboop.app.dataprovider.VarNameGenerator.generateUniqueName("inner") + val innerIterator = components.IteratorDef(innerIterName, dependentCollection) + + val newInnerLoop = components.ForAllExpression(innerIterator, outerLoop.body) + + List(components.ForAllExpression(outerLoop.iteratorDef, newInnerLoop)) + + case _ => Nil + } + } } \ No newline at end of file diff --git a/src/main/scala/com/beepboop/app/mutations/MutationEngine.scala b/src/main/scala/com/beepboop/app/mutations/MutationEngine.scala index bf29e2b..6191a7b 100644 --- a/src/main/scala/com/beepboop/app/mutations/MutationEngine.scala +++ b/src/main/scala/com/beepboop/app/mutations/MutationEngine.scala @@ -1,8 +1,9 @@ package com.beepboop.app.mutations -import pureconfig._ -import pureconfig.module.yaml._ -import pureconfig.generic.derivation.default._ +import com.beepboop.app.utils.AppConfig +import pureconfig.* +import pureconfig.module.yaml.* +import pureconfig.generic.derivation.default.* /* own modules */ import com.beepboop.app.components.* @@ -10,11 +11,7 @@ import com.beepboop.app.logger.LogTrait object AllMutations { - val allLoaded: List[Mutation] = YamlConfigSource.file("config.yaml") - .at("mutations") - .loadOrThrow[List[Mutation]] - - val mutations = allLoaded.filter(_.enabled) + val mutations = AppConfig.enabledMutations val directory: Map[String, Mutation] = mutations.map(m => m.name -> m).toMap @@ -70,7 +67,13 @@ class MutationEngine(val activeMutations: List[Mutation]) extends LogTrait { case _ => UnknownType } - val innerCtx = ctx.withVariable(f.iteratorDef.variableName, innerType) + val domainExpr = f.iteratorDef.collection + + val innerCtx = ctx.withVariable( + f.iteratorDef.variableName, + innerType, + domainExpr + ) collectPossibleMutations(f.iteratorDef.collection, ctx) ++ collectPossibleMutations(f.body, innerCtx) @@ -87,6 +90,7 @@ class MutationEngine(val activeMutations: List[Mutation]) extends LogTrait { def replaceNodeInTree(root: Expression[?], target: Expression[?], replacement: Expression[?]): Expression[?] = { if (root eq target) { + replacement.creatorInfo += s" | Mutated by replacement logic" return replacement } @@ -103,4 +107,6 @@ class MutationEngine(val activeMutations: List[Mutation]) extends LogTrait { case other => other } } + + } \ No newline at end of file diff --git a/src/main/scala/com/beepboop/app/policy/Base.scala b/src/main/scala/com/beepboop/app/policy/Base.scala index 20c64bb..2b17a1f 100644 --- a/src/main/scala/com/beepboop/app/policy/Base.scala +++ b/src/main/scala/com/beepboop/app/policy/Base.scala @@ -51,7 +51,7 @@ case class EnsureAnyVarExists() extends GlobalPolicy { case class EnsureSpecificVarExists(targetName: String) extends GlobalPolicy { private var found = false - override def message: String = s"Expr doesn't contain '$targetName' in scope" + override def message: String = s"Expr doesn't contain required variable in scope" override def reset(): Unit = { found = false @@ -108,4 +108,26 @@ case class NoDuplicateVar() extends LocalPolicy { case _ => Compliant } +} + +case class MaxDepth(limit: Int) extends GlobalPolicy { + override def message: String = s"Expression exceeds max depth of $limit" + + override def reset(): Unit = {} + + override def visit(node: Expression[?]): Unit = {} + + + override def isSatisfied: Boolean = true +} + +object DepthChecker { + def exceeds(expr: Expression[?], limit: Int, current: Int = 0): Boolean = { + if (current > limit) return true + expr match { + case c: com.beepboop.app.components.ComposableExpression => + c.children.exists(child => exceeds(child, limit, current + 1)) + case _ => false + } + } } \ No newline at end of file diff --git a/src/main/scala/com/beepboop/app/postprocessor/Postprocessor.scala b/src/main/scala/com/beepboop/app/postprocessor/Postprocessor.scala index 3f7a38f..d402e7f 100644 --- a/src/main/scala/com/beepboop/app/postprocessor/Postprocessor.scala +++ b/src/main/scala/com/beepboop/app/postprocessor/Postprocessor.scala @@ -423,6 +423,7 @@ object Postprocessor { } def simplify[T : ClassTag](expr: Expression[T]): Expression[T] = { + val originalOutput = expr.signature.output val nodeWithSimplifiedChildren: Expression[T] = expr match { case c: ComposableExpression => val newChildren = c.children.map { @@ -435,5 +436,15 @@ object Postprocessor { leaf } Postprocessor.applyAllRules(nodeWithSimplifiedChildren) + + val finalNode = Postprocessor.applyAllRules(nodeWithSimplifiedChildren) + + finalNode.creatorInfo = s"Simplified from: ${expr.getClass.getSimpleName} (Original Creator: ${expr.creatorInfo})" + + if (finalNode.signature.output != originalOutput) { + return nodeWithSimplifiedChildren + } + + finalNode } } \ No newline at end of file diff --git a/src/main/scala/com/beepboop/app/utils/ArgumentParser.scala b/src/main/scala/com/beepboop/app/utils/ArgumentParser.scala index 2604705..ddabe28 100644 --- a/src/main/scala/com/beepboop/app/utils/ArgumentParser.scala +++ b/src/main/scala/com/beepboop/app/utils/ArgumentParser.scala @@ -1,6 +1,11 @@ package com.beepboop.app.utils -import mainargs.{arg, ParserForClass, Flag} +import pureconfig.module.yaml.YamlConfigSource +import pureconfig.ConfigSource +import com.beepboop.app.dataprovider.{AlgorithmConfig, ClassLogConfig} +import com.beepboop.app.mutations.Mutation +import mainargs.{Flag, ParserForClass, arg} +import pureconfig.ConfigSource case class GeneratorConfig( @arg(positional = true, doc = "Path to the .mzn model file") @@ -46,7 +51,10 @@ case class GeneratorConfig( analyzeModel: String = "nothing", @arg(name = "gurobi-license", doc = "Path to gurobi license file") - gurobiLicense: String = "" + gurobiLicense: String = "", + + @arg(name = "config", doc = "Path to config file") + config: String = "config.yaml" ) case class AppConfig( @@ -64,9 +72,42 @@ case class AppConfig( debug: Boolean, gurobiLicense: String, analyze: Boolean, - analyzeModel: String + analyzeModel: String, + config: String, ) + +object AppConfig { + + private var _instance: Option[AppConfig] = None + + def init(config: AppConfig): Unit = { + _instance = Some(config) + } + + def get: AppConfig = _instance.getOrElse( + throw new IllegalStateException("AppConfig not initialized! Call AppConfig.init() in Main.") + ) + + lazy val algorithm: AlgorithmConfig = { + val yamlPath = get.config + YamlConfigSource.file(yamlPath).loadOrThrow[AlgorithmConfig] + } + + lazy val enabledMutations: List[Mutation] = { + algorithm.mutations.filter(_.enabled) + } + + def getWeight(componentName: String): Double = { + algorithm.expressionWeights.getOrElse(componentName, 0.0) + } + + def getClassLogConfig(className: String): ClassLogConfig = { + val cleanName = className.stripSuffix("$") + algorithm.logging.classes.getOrElse(cleanName, ClassLogConfig()) + } +} + object ArgumentParser { private val parser = ParserForClass[GeneratorConfig] @@ -122,7 +163,8 @@ object ArgumentParser { debug = cli.debug.value, gurobiLicense = cli.gurobiLicense, analyze = cli.analyze.value, - analyzeModel = cli.analyzeModel + analyzeModel = cli.analyzeModel, + config = cli.config )) case Left(errorMsg) => diff --git a/src/test/scala/com/beepboop/app/ConstraintsHeuristicSuiteTest.scala b/src/test/scala/com/beepboop/app/ConstraintsHeuristicSuiteTest.scala index e09e66e..7120fa3 100644 --- a/src/test/scala/com/beepboop/app/ConstraintsHeuristicSuiteTest.scala +++ b/src/test/scala/com/beepboop/app/ConstraintsHeuristicSuiteTest.scala @@ -1,10 +1,20 @@ package com.beepboop.app -import com.beepboop.app.components._ -import com.beepboop.app.utils.Implicits._ +import com.beepboop.app.components.* +import com.beepboop.app.utils.{AppConfig, ArgumentParser} +import com.beepboop.app.utils.Implicits.* +import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite -class ConstraintsHeuristicSuiteTest extends AnyFunSuite { +class ConstraintsHeuristicSuiteTest extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { + val testArgs = Array(".", "--config", "config.yaml") + + val configObject = ArgumentParser.parse(testArgs).get + + AppConfig.init(configObject) + } + private def i(n: Int): Integer = Integer.valueOf(n) private def c(value: Int): Constant[Integer] = Constant(i(value)) diff --git a/src/test/scala/com/beepboop/app/ExpressionSuiteTest.scala b/src/test/scala/com/beepboop/app/ExpressionSuiteTest.scala index 63b57a4..817d39a 100644 --- a/src/test/scala/com/beepboop/app/ExpressionSuiteTest.scala +++ b/src/test/scala/com/beepboop/app/ExpressionSuiteTest.scala @@ -1,10 +1,20 @@ package com.beepboop.app -import com.beepboop.app.components._ +import com.beepboop.app.components.* +import com.beepboop.app.utils.{AppConfig, ArgumentParser} import org.scalatest.funsuite.AnyFunSuite import com.beepboop.app.utils.Implicits.* +import org.scalatest.BeforeAndAfterAll -class ForAllExpressionSuite extends AnyFunSuite { +class ForAllExpressionSuite extends AnyFunSuite with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + val testArgs = Array(".", "--config", "config.yaml") + + val configObject = ArgumentParser.parse(testArgs).get + + AppConfig.init(configObject) + } test("ForAllExpression should evaluate to true when all elements satisfy the predicate") { // forall(i in [1, 2, 3]))(i > 0) diff --git a/src/test/scala/com/beepboop/app/ModelParsingTest.scala b/src/test/scala/com/beepboop/app/ModelParsingTest.scala index 860f2ab..8602a50 100644 --- a/src/test/scala/com/beepboop/app/ModelParsingTest.scala +++ b/src/test/scala/com/beepboop/app/ModelParsingTest.scala @@ -1,8 +1,12 @@ import com.beepboop.app.ParserUtil +import com.beepboop.app.components.{IntType, ListIntType, ListSetIntType} +import com.beepboop.app.dataprovider.{DataProvider, ModelParser} +import com.beepboop.app.utils.{AppConfig, ArgumentParser} import com.beepboop.parser.{NewMinizincLexer, NewMinizincParser} import org.scalatest.funsuite.AnyFunSuite import org.scalatest.Inspectors.* import org.antlr.v4.runtime.* +import org.scalatest.BeforeAndAfterAll import scala.io.Source @@ -21,7 +25,56 @@ private object ThrowingErrorListener extends BaseErrorListener { } } +class VariableExtractionSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { + val testArgs = Array(".", "--config", "config.yaml") + + val configObject = ArgumentParser.parse(testArgs).get + + AppConfig.init(configObject) + } + + test("ModelParser should extract variables and parameters with correct InternalTypes") { + val filename = "src/test/resources/models/accap_test.mzn" + val (parameters, variables) = ModelParser.getDataItems(filename) + val allItems = parameters ++ variables + + val expectedTypes = Map( + "flights" -> IntType, + "FLIGHT" -> ListIntType, + "times" -> IntType, + "TIME" -> ListIntType, + "COUNTER" -> ListIntType, + "AIRLINE" -> ListIntType, + "airlines" -> IntType, + "FA" -> ListSetIntType, + //"ISet" -> ListSetIntType, + "cNum" -> ListIntType, + "xCoor" -> ListIntType, + //"yCoor" -> ListIntType, -- silence test for a while, it evals correctly with solution data + "opDur" -> ListIntType, + "S" -> ListIntType, + "D" -> IntType + ) + + expectedTypes.foreach { case (name, expectedType) => + val itemOpt = allItems.find(_.name == name) + + assert(itemOpt.isDefined, s"Expected item '$name' was not found in the model") + + val item = itemOpt.get + if (item.name == "times") { + println(s"Detailed Data for 'times': isArray=${item.detailedDataType.isArray}, isSet=${item.detailedDataType.isSet}") + } + val actualType = DataProvider.getExpressionType(item) + + assert(actualType == expectedType, + s"Type mismatch for '$name' Expected: $expectedType, but got: $actualType (Raw: ${item.dataType})") + } + } + + } class OverallGrammarParserSuite extends AnyFunSuite { val modelsToTest: List[String] = List.apply( "australia_test.mzn", diff --git a/src/test/scala/com/beepboop/app/PostprocessorSuiteTest.scala b/src/test/scala/com/beepboop/app/PostprocessorSuiteTest.scala index 39b972c..0e44b53 100644 --- a/src/test/scala/com/beepboop/app/PostprocessorSuiteTest.scala +++ b/src/test/scala/com/beepboop/app/PostprocessorSuiteTest.scala @@ -1,13 +1,21 @@ import com.beepboop.app.components.{SubOperator, *} import com.beepboop.app.postprocessor.Postprocessor +import com.beepboop.app.utils.{AppConfig, ArgumentParser} import org.scalatest.funsuite.AnyFunSuite import org.scalatest.Inspectors.* -import org.scalatest.Tag +import org.scalatest.{BeforeAndAfterAll, Tag} -class PostprocessorSuite extends AnyFunSuite { +class PostprocessorSuite extends AnyFunSuite with BeforeAndAfterAll { + override def beforeAll(): Unit = { + val testArgs = Array(".", "--config", "config.yaml") + + val configObject = ArgumentParser.parse(testArgs).get + + AppConfig.init(configObject) + } test("Unary expression should be simplified on constant abs(-2) | JoinConstants") { diff --git a/src/test/scala/com/beepboop/app/SimilaritySuiteTest.scala b/src/test/scala/com/beepboop/app/SimilaritySuiteTest.scala new file mode 100644 index 0000000..e69de29