diff --git a/src/main/scala/mimir/Database.scala b/src/main/scala/mimir/Database.scala index 85a2d2ab..26499ad7 100644 --- a/src/main/scala/mimir/Database.scala +++ b/src/main/scala/mimir/Database.scala @@ -105,6 +105,8 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) val ra = new mimir.sql.RAToSql(this) val functions = new mimir.algebra.function.FunctionRegistry val aggregates = new mimir.algebra.function.AggregateRegistry + val types:mimir.algebra.typeregistry.TypeRegistry + = mimir.algebra.typeregistry.DefaultTypeRegistry //// Logic val compiler = new mimir.exec.Compiler(this) @@ -113,11 +115,17 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) val typechecker = new mimir.algebra.Typechecker( functions = Some(functions), aggregates = Some(aggregates), - models = Some(models) + models = Some(models), + types = types ) val interpreter = new mimir.algebra.Eval( functions = Some(functions) ) + //// Translation + val gpromTranslator = new mimir.algebra.gprom.OperatorTranslation(this) + val sparkTranslator = new mimir.algebra.spark.OperatorTranslation(this) + + val metadataTables = Seq("MIMIR_ADAPTIVE_SCHEMAS", "MIMIR_MODEL_OWNERS", "MIMIR_MODELS", "MIMIR_VIEWS", "MIMIR_SYS_TABLES", "MIMIR_SYS_ATTRS") /** * Optimize and evaluate the specified query. Applies all Mimir-specific optimizations @@ -239,6 +247,11 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) case None => backend.getTableSchema(name) } } + /** + * Look up the raw typed schema for the specified table + */ + def tableBaseSchema(name: String): Option[Seq[(String,BaseType)]] = + tableSchema(name).map { _.map { case (name,t) => (name, types.rootType(t)) } } /** * Build a Table operator for the table with the provided name. 
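Reviewer note: the new `types` registry member and `tableBaseSchema` are the crux of this hunk — schemas keep registry-resolved user types, while anything that needs a concrete storage type folds down through `rootType`. A minimal sketch of the intended behavior, not part of the patch; it assumes `DefaultTypeRegistry` maps a user type such as "zipcode" onto a `TString()` root, as the old hard-coded `TypeRegistry` did:

    import mimir.algebra._
    import mimir.algebra.typeregistry.DefaultTypeRegistry

    // Fold any Mimir schema down to base types, exactly as the new
    // Database.tableBaseSchema does via db.types.rootType.
    def toBaseSchema(schema: Seq[(String, Type)]): Seq[(String, BaseType)] =
      schema.map { case (col, t) => (col, DefaultTypeRegistry.rootType(t)) }

    // Assuming "zipcode" is registered with root TString():
    //   toBaseSchema(Seq("ZIP" -> TUser("zipcode"), "POP" -> TInt()))
    //   == Seq("ZIP" -> TString(), "POP" -> TInt())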
@@ -458,8 +471,6 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) views.init() lenses.init() adaptiveSchemas.init() - mimir.algebra.gprom.OperatorTranslation(this) - mimir.algebra.spark.OperatorTranslation(this) } /** @@ -492,7 +503,7 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) sourceFile: File ) : Unit = loadTable(targetTable, sourceFile, true, ("CSV", Seq(StringPrimitive(","),BoolPrimitive(false),BoolPrimitive(false))), - Some(targetSchema.map(el => (el._1, Type.fromString(el._2))))) + Some(targetSchema.map(el => (el._1, types.rootType(types.fromString(el._2)))))) def loadTable( targetTable: String, @@ -511,7 +522,7 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) sourceFile: File, force:Boolean, format:(String, Seq[PrimitiveValue]), - targetSchema: Option[Seq[(String, Type)]] + targetSchema: Option[Seq[(String, BaseType)]] ){ val (delim, typeinference, detectHeaders) = format._1.toUpperCase() match { case "CSV" => { @@ -535,7 +546,7 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) targetTable: String, sourceFile: File, force:Boolean, - targetSchema: Option[Seq[(String, Type)]] = None, + targetSchema: Option[Seq[(String, BaseType)]] = None, inferTypes:Boolean = true, detectHeaders:Boolean = true, backendOptions:Map[String, String] = Map(), @@ -565,7 +576,7 @@ case class Database(backend: RABackend, metadataBackend: MetadataBackend) views.create(targetTable.toUpperCase, oper) } else { val schema = targetSchema match { - case None => tableSchema(targetTable) + case None => tableBaseSchema(targetTable) case _ => targetSchema } LoadData.handleLoadTableRaw(this, targetTable.toUpperCase, schema, sourceFile, options, format) diff --git a/src/main/scala/mimir/Mimir.scala b/src/main/scala/mimir/Mimir.scala index a32edecd..3c0938a7 100644 --- a/src/main/scala/mimir/Mimir.scala +++ b/src/main/scala/mimir/Mimir.scala @@ -58,9 +58,9 @@ object Mimir extends LazyLogging { if(!conf.quiet()){ output.print("Connecting to " + conf.backend() + "://" + conf.dbname() + "...") } + sback.sparkTranslator = db.sparkTranslator db.metadataBackend.open() db.backend.open() - OperatorTranslation.db = db sback.registerSparkFunctions(db.functions.functionPrototypes.map(el => el._1).toSeq, db.functions) sback.registerSparkAggregates(db.aggregates.prototypes.map(el => el._1).toSeq, db.aggregates) diff --git a/src/main/scala/mimir/MimirVizier.scala b/src/main/scala/mimir/MimirVizier.scala index 19df4a3a..d3f69d98 100644 --- a/src/main/scala/mimir/MimirVizier.scala +++ b/src/main/scala/mimir/MimirVizier.scala @@ -68,6 +68,7 @@ object MimirVizier extends LazyLogging { val database = Mimir.conf.dbname().split("[\\\\/]").last.replaceAll("\\..*", "") val sback = new SparkBackend(database) db = new Database(sback, new JDBCMetadataBackend(Mimir.conf.backend(), Mimir.conf.dbname())) + sback.sparkTranslator = db.sparkTranslator db.metadataBackend.open() db.backend.open() val otherExcludeFuncs = Seq("NOT","AND","!","%","&","*","+","-","/","<","<=","<=>","=","==",">",">=","^") @@ -335,8 +336,8 @@ object MimirVizier extends LazyLogging { logger.debug(s"loadCSV: From Vistrails: [ $file ] inferTypes: $inferTypes detectHeaders: $detectHeaders format: ${format} -> [ ${backendOptions.mkString(",")} ]") ; val bkOpts = backendOptions.map{ case (optKey:String, optVal:String) => (optKey, optVal) - case hm:java.util.HashMap[String,String] => { - val entry = hm.entrySet().iterator().next() + case hm:java.util.HashMap[_,_] => { + 
val entry = hm.asInstanceOf[java.util.HashMap[String,String]].entrySet().iterator().next() (entry.getKey, entry.getValue) } case _ => throw new Exception("loadCSV: bad options type") @@ -459,11 +460,15 @@ object MimirVizier extends LazyLogging { val timeRes = logTime("createLens") { logger.debug("createView: From Vistrails: [" + input + "] [" + query + "]" ) ; val (viewNameSuffix, inputSubstitutionQuery) = input match { - case aliases:JMapWrapper[String,String] => { + case rawAliases:JMapWrapper[_,_] => { + // The following line needed to suppress compiler type erasure warning + val aliases = rawAliases.asInstanceOf[JMapWrapper[String,String]] aliases.map{ case (vizierName, mimirName) => db.sql.registerVizierNameMapping(vizierName.toUpperCase(), mimirName) } (aliases.unzip._2.mkString(""), query) } - case inputs:Seq[String] => { + case rawInputs:Seq[_] => { + // The following line needed to suppress compiler type erasure warning + val inputs = rawInputs.asInstanceOf[Seq[String]] (inputs.mkString(""),inputs.zipWithIndex.foldLeft(query)((init, curr) => { init.replaceAll(s"\\{\\{\\s*input_${curr._2}\\s*\\}\\}", curr._1) })) @@ -570,11 +575,15 @@ object MimirVizier extends LazyLogging { def vistrailsQueryMimirJson(input:Any, query : String, includeUncertainty:Boolean, includeReasons:Boolean) : String = { val inputSubstitutionQuery = input match { - case aliases:JMapWrapper[String,String] => { + case rawAliases:JMapWrapper[_,_] => { + // The following line needed to suppress compiler type erasure warning + val aliases = rawAliases.asInstanceOf[JMapWrapper[String,String]] aliases.map{ case (vizierName, mimirName) => db.sql.registerVizierNameMapping(vizierName.toUpperCase(), mimirName) } query } - case inputs:Seq[String] => { + case rawInputs:Seq[_] => { + // The following line needed to suppress compiler type erasure warning + val inputs = rawInputs.asInstanceOf[Seq[String]] inputs.zipWithIndex.foldLeft(query)((init, curr) => { init.replaceAll(s"\\{\\{\\s*input_${curr._2}\\s*\\}\\}", curr._1) }) @@ -732,7 +741,7 @@ def vistrailsQueryMimirJson(query : String, includeUncertainty:Boolean, includeR try{ logger.debug("getSchema: From Vistrails: [" + query + "]" ) ; val oper = totallyOptimize(db.sql.convert(db.parse(query).head.asInstanceOf[Select])) - JSONBuilder.list( db.typechecker.schemaOf(oper).map( schel => Map( "name" -> schel._1, "type" -> schel._2.toString(), "base_type" -> Type.rootType(schel._2).toString()))) + JSONBuilder.list( db.typechecker.schemaOf(oper).map( schel => Map( "name" -> schel._1, "type" -> schel._2.toString(), "base_type" -> db.types.rootType(schel._2).toString()))) } catch { case t: Throwable => { logger.error("Error Getting Schema: [" + query + "]", t) @@ -1105,7 +1114,7 @@ def vistrailsQueryMimirJson(query : String, includeUncertainty:Boolean, includeR val resultList = results.toList val (resultsStrs, prov) = resultList.map(row => (row.tuple.map(cell => cell), row.provenance.asString)).unzip JSONBuilder.dict(Map( - "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> Type.rootType(schel._2).toString())), + "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> db.types.rootType(schel._2).toString())), "data" -> resultsStrs, "prov" -> prov )) @@ -1120,7 +1129,7 @@ def vistrailsQueryMimirJson(query : String, includeUncertainty:Boolean, includeR val (resultsStrs, colTaint) = resultsStrsColTaint.unzip val (prov, rowTaint) = provRowTaint.unzip 
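Reviewer note on the `asInstanceOf` casts introduced in the hunks above: Scala erases generic type arguments at runtime, so a pattern like `case aliases: JMapWrapper[String,String]` can never actually check its parameters and the compiler warns. Matching on the wildcard shape and casting afterwards is the standard warning-free idiom. A standalone sketch of the same pattern (the function name is illustrative, not from the patch):

    // Erasure-safe variant of the option normalization in loadCSV:
    def normalizeOption(opt: Any): (String, String) =
      opt match {
        case (k: String, v: String) => (k, v)
        // Type arguments are erased at runtime, so match the raw shape...
        case hm: java.util.HashMap[_, _] =>
          // ...and cast afterwards, as the patch does.
          val e = hm.asInstanceOf[java.util.HashMap[String, String]]
                    .entrySet().iterator().next()
          (e.getKey, e.getValue)
        case x => throw new Exception(s"loadCSV: bad options type: $x")
      }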
JSONBuilder.dict(Map( - "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> Type.rootType(schel._2).toString())), + "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> db.types.rootType(schel._2).toString())), "data" -> resultsStrs, "prov" -> prov, "col_taint" -> colTaint, @@ -1138,7 +1147,7 @@ def vistrailsQueryMimirJson(query : String, includeUncertainty:Boolean, includeR val (prov, rowTaint) = provRowTaint.unzip val reasons = explainEverything(oper).map(reasonSet => reasonSet.all(db).toSeq.map(_.toJSONWithFeedback)) JSONBuilder.dict(Map( - "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> Type.rootType(schel._2).toString())), + "schema" -> results.schema.map( schel => Map( "name" -> schel._1, "type" ->schel._2.toString(), "base_type" -> db.types.rootType(schel._2).toString())), "data" -> resultsStrs, "prov" -> prov, "col_taint" -> colTaint, diff --git a/src/main/scala/mimir/adaptive/AdaptiveSchemaManager.scala b/src/main/scala/mimir/adaptive/AdaptiveSchemaManager.scala index c8f0c345..82ac143d 100644 --- a/src/main/scala/mimir/adaptive/AdaptiveSchemaManager.scala +++ b/src/main/scala/mimir/adaptive/AdaptiveSchemaManager.scala @@ -63,9 +63,9 @@ class AdaptiveSchemaManager(db: Database) ){ _.map { row => val name = row(0).asString val mlensType = row(1).asString - val query = Json.toOperator(Json.parse(row(2).asString)) + val query = Json.toOperator(Json.parse(row(2).asString), db.types) val args:Seq[Expression] = - Json.toExpressionList(Json.parse(row(3).asString)) + Json.toExpressionList(Json.parse(row(3).asString), db.types) ( MultilensRegistry.multilenses(mlensType), @@ -87,9 +87,9 @@ class AdaptiveSchemaManager(db: Database) ){ _.map { row => val name = row(0).asString val mlensType = row(1).asString - val query = Json.toOperator(Json.parse(row(2).asString)) + val query = Json.toOperator(Json.parse(row(2).asString), db.types) val args:Seq[Expression] = - Json.toExpressionList(Json.parse(row(3).asString)) + Json.toExpressionList(Json.parse(row(3).asString), db.types) ( MultilensRegistry.multilenses(mlensType), @@ -166,9 +166,9 @@ class AdaptiveSchemaManager(db: Database) val row = result.next val name = row(0).asString val mlensType = row(1).asString - val query = Json.toOperator(Json.parse(row(2).asString)) + val query = Json.toOperator(Json.parse(row(2).asString), db.types) val args:Seq[Expression] = - Json.toExpressionList(Json.parse(row(3).asString)) + Json.toExpressionList(Json.parse(row(3).asString), db.types) Some(( MultilensRegistry.multilenses(mlensType), diff --git a/src/main/scala/mimir/adaptive/DiscalaAbadiNormalizer.scala b/src/main/scala/mimir/adaptive/DiscalaAbadiNormalizer.scala index f83f96fe..da5a80f3 100644 --- a/src/main/scala/mimir/adaptive/DiscalaAbadiNormalizer.scala +++ b/src/main/scala/mimir/adaptive/DiscalaAbadiNormalizer.scala @@ -74,8 +74,8 @@ object DiscalaAbadiNormalizer fullSchema.map { x => (x._2.toLong -> x._1._1) }.toMap ) groupingModel.trainDomain(db)//.reconnectToDatabase(db) - val schemaLookup = - fullSchema.map( x => (x._2 -> x._1) ).toMap + val schemaLookup:Map[Int,(String,BaseType)] = + fullSchema.map( x => (x._2 -> (x._1._1, db.types.rootType(x._1._2))) ).toMap // for every possible parent/child relationship except for ROOT val parentKeyRepairs = @@ -261,9 +261,9 @@ class DAFDRepairModel( name: String, context: String, source: Operator, - keys: Seq[(String, Type)], + 
keys: Seq[(String, BaseType)], target: String, - targetType: Type, + targetType: BaseType, scoreCol: Option[String], attrLookup: Map[Long,String] ) extends RepairKeyModel(name, context, source, keys, target, targetType, scoreCol) diff --git a/src/main/scala/mimir/adaptive/SchemaMatching.scala b/src/main/scala/mimir/adaptive/SchemaMatching.scala index 1dbfb4c8..7c04dc66 100644 --- a/src/main/scala/mimir/adaptive/SchemaMatching.scala +++ b/src/main/scala/mimir/adaptive/SchemaMatching.scala @@ -27,7 +27,7 @@ object SchemaMatching val typeName = split(1) ( varName.toString.toUpperCase -> - Type.fromString(typeName.toString) + db.types.rootType(db.types.fromString(typeName.toString)) ) }). toList @@ -118,7 +118,7 @@ object SchemaMatching val typeName = split(1) ( varName.toString.toUpperCase -> - Type.fromString(typeName.toString) + db.types.fromString(typeName.toString) ) }).toList targetSchema.tail.foldLeft( @@ -149,16 +149,16 @@ object SchemaMatching { if(table.equals("DATA")){ val targetSchema = - config.args. - map(field => { - val split = db.interpreter.evalString(field).split(" +") - val varName = split(0).toUpperCase - val typeName = split(1) - ( - varName.toString.toUpperCase -> - Type.fromString(typeName.toString) - ) - }).toList + config.args. + map(field => { + val split = db.interpreter.evalString(field).split(" +") + val varName = split(0).toUpperCase + val typeName = split(1) + ( + varName.toString.toUpperCase -> + db.types.fromString(typeName.toString) + ) + }).toList Some(Project( targetSchema.map { case (colName, colType) => { val metaModel = db.models.get(s"${config.schema}:META:$colName") diff --git a/src/main/scala/mimir/adaptive/TypeInference.scala b/src/main/scala/mimir/adaptive/TypeInference.scala index be559809..0bfc225f 100644 --- a/src/main/scala/mimir/adaptive/TypeInference.scala +++ b/src/main/scala/mimir/adaptive/TypeInference.scala @@ -15,14 +15,14 @@ object TypeInference { - def detectType(v: String): Iterable[Type] = { - Type.tests.flatMap({ case (t, regexp) => - regexp.findFirstMatchIn(v).map(_ => t) - })++ - TypeRegistry.matchers.flatMap({ case (regexp, name) => - regexp.findFirstMatchIn(v).map(_ => TUser(name)) - }) - } + // def detectType(v: String, types:TypeRegistry): Iterable[Type] = { + // Type.tests.flatMap({ case (t, regexp) => + // regexp.findFirstMatchIn(v).map(_ => t) + // })++ + // TypeRegistry.matchers.flatMap({ case (regexp, name) => + // regexp.findFirstMatchIn(v).map(_ => TUser(name)) + // }) + // } def initSchema(db: Database, config: MultilensConfig): TraversableOnce[Model] = { @@ -49,7 +49,8 @@ object TypeInference modelColumns, stringDefaultScore, db.backend.asInstanceOf[BackendWithSparkContext].getSparkContext(), - Some(db.backend.execute(config.query.limit(TypeInferenceModel.sampleLimit, 0))) + Some(db.backend.execute(config.query.limit(TypeInferenceModel.sampleLimit, 0))), + db.types.getSerializable ) val columnIndexes = @@ -115,14 +116,26 @@ object TypeInference if(table.equals("DATA")){ val model = db.models.get(s"MIMIR_TI_ATTR_${config.schema}").asInstanceOf[TypeInferenceModel] val columnIndexes = model.columns.zipWithIndex.toMap + Some(Project( config.query.columnNames.map { colName => ProjectArg(colName, if(columnIndexes contains colName){ - Function("CAST", Seq( + + val targetType = + model.bestGuess(0, + Seq(IntPrimitive(columnIndexes(colName))), + Seq() + ).asInstanceOf[TypePrimitive].t + + val castExpr = db.types.typeCaster(targetType, Var(colName)) + + Conditional( + IsNullExpression(Var(colName)), Var(colName), - 
model.bestGuess(0, Seq(IntPrimitive(columnIndexes(colName))), Seq()) - )) + castExpr + ) + } else { Var(colName) } diff --git a/src/main/scala/mimir/algebra/Cast.scala b/src/main/scala/mimir/algebra/Cast.scala index c2589fcb..774c16c9 100644 --- a/src/main/scala/mimir/algebra/Cast.scala +++ b/src/main/scala/mimir/algebra/Cast.scala @@ -4,7 +4,7 @@ import mimir.util._ object Cast { - def apply(t: Type, x: PrimitiveValue): PrimitiveValue = + def apply(t: BaseType, x: PrimitiveValue): PrimitiveValue = { try { t match { @@ -27,18 +27,12 @@ object Cast case _ => TextUtils.parseInterval(x.asString) } case TRowId() => RowIdPrimitive(x.asString) - case TAny() => x case TBool() => BoolPrimitive(x.asLong != 0) - case TType() => TypePrimitive(Type.fromString(x.asString)) - case TUser(name) => { - val (typeRegexp, baseT) = TypeRegistry.registeredTypes(name) - val base = apply(baseT, x) - if(typeRegexp.findFirstMatchIn(base.asString).isEmpty){ - NullPrimitive() - } else { - base - } - } + case TType() => TypePrimitive( + BaseType.fromString(x.asString) + .getOrElse{ TUser(x.asString) } + ) + case TAny() => x } } catch { case _:TypeException=> NullPrimitive(); @@ -46,6 +40,6 @@ object Cast } } - def apply(t: Type, x: String): PrimitiveValue = + def apply(t: BaseType, x: String): PrimitiveValue = apply(t, StringPrimitive(x)) } diff --git a/src/main/scala/mimir/algebra/Eval.scala b/src/main/scala/mimir/algebra/Eval.scala index e5010bd2..172ceff7 100644 --- a/src/main/scala/mimir/algebra/Eval.scala +++ b/src/main/scala/mimir/algebra/Eval.scala @@ -213,12 +213,12 @@ object Eval ): PrimitiveValue = { if(a.equals(NullPrimitive()) || b.equals(NullPrimitive())) { return NullPrimitive() } - val aRoot = Type.rootType(a.getType) - val bRoot = Type.rootType(b.getType) + val typeOfA = a.getType + val typeOfB = b.getType - (op, aRoot, bRoot, + (op, typeOfA, typeOfB, Typechecker.escalate( - aRoot, bRoot, op, "Evaluate Arithmetic", Arithmetic(op, a, b) + typeOfA, typeOfB, op, "Evaluate Arithmetic", Arithmetic(op, a, b) )) match { case (Arith.Add, _, _, TInt()) => IntPrimitive(a.asLong + b.asLong) diff --git a/src/main/scala/mimir/algebra/Expression.scala b/src/main/scala/mimir/algebra/Expression.scala index 00036fa5..031ba514 100644 --- a/src/main/scala/mimir/algebra/Expression.scala +++ b/src/main/scala/mimir/algebra/Expression.scala @@ -1,6 +1,5 @@ package mimir.algebra; -import mimir.algebra.Type._; import org.joda.time.DateTime; import org.joda.time.Period; @@ -318,7 +317,7 @@ case class IsNullExpression(child: Expression) extends Expression { * Slightly more specific base type for constant terms. PrimitiveValue * also acts as a boxing type for constants in Mimir. 
*/ -abstract sealed class PrimitiveValue(t: Type) +abstract sealed class PrimitiveValue(t: BaseType) extends LeafExpression with Serializable { def getType = t @@ -375,7 +374,7 @@ abstract sealed class PrimitiveValue(t: Type) def payload: Object; } -abstract sealed class NumericPrimitive(t: Type) extends PrimitiveValue(t) +abstract sealed class NumericPrimitive(t: BaseType) extends PrimitiveValue(t) /** * Boxed representation of a long integer @@ -590,7 +589,7 @@ case class IntervalPrimitive(p: Period) */ abstract class Proc(val args: Seq[Expression]) extends Expression { - def getType(argTypes: Seq[Type]): Type + def getType(argTypes: Seq[BaseType]): BaseType def getArgs = args def children = args def get(v: Seq[PrimitiveValue]): PrimitiveValue diff --git a/src/main/scala/mimir/algebra/Type.scala b/src/main/scala/mimir/algebra/Type.scala index d0878530..5a87f9d1 100644 --- a/src/main/scala/mimir/algebra/Type.scala +++ b/src/main/scala/mimir/algebra/Type.scala @@ -2,150 +2,114 @@ package mimir.algebra; import scala.collection.mutable.ListBuffer import scala.util.matching.Regex +import mimir.util.SealedSubclassEnumeration /** * An enum class defining the type of primitive-valued expressions * (e.g., integers, floats, strings, etc...) */ sealed trait Type -{ - override def toString(): String = { - Type.toString(this) - } -} -object Type { - val rootTypes = Seq(TInt(), TFloat(), TDate(), TTimestamp(), TString(), TBool(), TInterval()) - - def toString(t:Type) = t match { - case TInt() => "int" - case TFloat() => "real" - case TDate() => "date" - case TTimestamp() => "datetime" - case TString() => "varchar" - case TBool() => "bool" - case TRowId() => "rowid" - case TType() => "type" - case TAny() => "any" - case TUser(name) => name.toLowerCase - case TInterval() => "interval" - } - def fromString(t: String) = t.toLowerCase match { - case "int" => TInt() - case "integer" => TInt() - case "float" => TFloat() - case "double" => TFloat() - case "decimal" => TFloat() - case "real" => TFloat() - case "num" => TFloat() - case "date" => TDate() - case "datetime" => TTimestamp() - case "timestamp" => TTimestamp() - case "interval" => TInterval() - case "varchar" => TString() - case "nvarchar" => TString() - case "char" => TString() - case "string" => TString() - case "text" => TString() - case "bool" => TBool() - case "rowid" => TRowId() - case "type" => TType() - case "any" => TAny() - case "" => TAny() // SQLite doesn't do types sometimes - case x if TypeRegistry.registeredTypes contains x => TUser(x) - case _ => - throw new RAException("Invalid Type '" + t + "'"); - } - def toSQLiteType(i:Int) = i match { - case 0 => TInt() - case 1 => TFloat() - case 2 => TDate() - case 3 => TString() - case 4 => TBool() - case 5 => TRowId() - case 6 => TType() - case 7 => TAny() - case 8 => TTimestamp() - case 9 => TInterval() - case _ => { - // 9 because this is the number of native types, if more are added then this number needs to increase - TUser(TypeRegistry.idxType(i-10)) +object BaseType { + + def toString(t:BaseType) = + t match { + case TInt() => "int" + case TFloat() => "real" + case TDate() => "date" + case TTimestamp() => "datetime" + case TString() => "varchar" + case TBool() => "bool" + case TRowId() => "rowid" + case TType() => "type" + case TInterval() => "interval" + case TAny() => "any" + } + + def fromString(t: String):Option[BaseType] = + t.toLowerCase match { + case "int" => Some(TInt()) + case "integer" => Some(TInt()) + case "float" => Some(TFloat()) + case "double" => Some(TFloat()) + case 
"decimal" => Some(TFloat()) + case "real" => Some(TFloat() ) + case "num" => Some(TFloat()) + case "date" => Some(TDate()) + case "datetime" => Some(TTimestamp()) + case "timestamp" => Some(TTimestamp()) + case "interval" => Some(TInterval()) + case "varchar" => Some(TString()) + case "nvarchar" => Some(TString()) + case "char" => Some(TString()) + case "string" => Some(TString()) + case "text" => Some(TString()) + case "bool" => Some(TBool()) + case "rowid" => Some(TRowId()) + case "type" => Some(TType()) + case "" => Some(TAny()) + case "any" => Some(TAny()) + case _ => None } - } - def id(t:Type) = t match { - case TInt() => 0 - case TFloat() => 1 - case TDate() => 2 - case TString() => 3 - case TBool() => 4 - case TRowId() => 5 - case TType() => 6 - case TAny() => 7 - case TTimestamp() => 8 - case TInterval() => 9 - case TUser(name) => TypeRegistry.typeIdx(name.toLowerCase)+10 - // 9 because this is the number of native types, if more are added then this number needs to increase - } - val tests = Map[Type,Regex]( + val tests = Seq[(BaseType,Regex)]( TInt() -> "^(\\+|-)?([0-9]+)$".r, TFloat() -> "^(\\+|-)?([0-9]*(\\.[0-9]+)?)$".r, - //TDate() -> "^[0-9]{4}[\\/\\\\-][0-9]{1,2}[\\/\\\\-][0-9]{1,2}$".r, - //TTimestamp() -> """^[0-9]{4}[\/\\-][0-9]{1,2}[\/\\-][0-9]{1,2}\s+[0-9]{1,2}:[0-9]{1,2}:(?:[0-9]{0,2}(?:\.[0-9]*)?)$""".r, TDate() -> "^[0-9]{4}\\-[0-9]{1,2}\\-[0-9]{1,2}$".r, TTimestamp() -> "^[0-9]{4}\\-[0-9]{1,2}\\-[0-9]{1,2}\\s+[0-9]{1,2}:[0-9]{1,2}:([0-9]{1,2}|[0-9]{0,2}\\.[0-9]*)$".r, TBool() -> "^(?i:true|false)$".r ) - def matches(t: Type, v: String): Boolean = - tests.get(t) match { - case Some(test) => !test.findFirstIn(v).isEmpty - case None => true - } - def rootType(t: Type): Type = - t match { - case TUser(t2) => rootType(TypeRegistry.baseType(t2)) - case t2 => t2 - } - - def isNumeric(t: Type, treatTAnyAsNumeric: Boolean = false): Boolean = - { - rootType(t) match { - case TInt() | TFloat() => true - case TAny() => treatTAnyAsNumeric - case _ => false - } + def guessBaseType(s:String): BaseType = { + tests.find { case (_,regex) => (regex findFirstIn s) != None } + .map { case (t,_) => t } + .getOrElse(TString()) } - def getType(s:String): Type = { - var t: Type = TString() // default - - if(s.matches("^(\\+|-)?([0-9]*(\\.[0-9]+)?)$")) // float - t = TFloat() - else if(s.matches("^(\\+|-)?([0-9]+)$")) // Int - t = TInt() - else if(s.matches("^(?i:true|false)$")) // Bool - t = TBool() - else if(s.matches("^[0-9]{4}\\-[0-9]{2}\\-[0-9]{2}$")) - t = TDate() - else if(s.matches("^[0-9]{4}\\-[0-9]{2}\\-[0-9]{2}\\ \\[0-9]{2}\\:[0-9]{2}\\:[0-9]{2}")) - t = TTimestamp() - t + val idTypeOrder = Seq[BaseType]( + TInt(), + TFloat(), + TDate(), + TString(), + TBool(), + TRowId(), + TType(), + TAny(), + TTimestamp(), + TInterval() + ) +} + +sealed trait BaseType extends Type +{ + override def toString(): String = { + BaseType.toString(this) } + val isNumeric = false +} +case class TInt() extends BaseType +{ + override val isNumeric = true } +case class TFloat() extends BaseType +{ + override val isNumeric = true +} + +case class TDate() extends BaseType +case class TString() extends BaseType +case class TBool() extends BaseType +case class TRowId() extends BaseType +case class TType() extends BaseType +case class TTimestamp() extends BaseType +case class TInterval() extends BaseType +case class TAny() extends BaseType -case class TInt() extends Type -case class TFloat() extends Type -case class TDate() extends Type -case class TString() extends Type -case class TBool() extends Type -case 
class TRowId() extends Type -case class TType() extends Type -case class TAny() extends Type case class TUser(name:String) extends Type -case class TTimestamp() extends Type -case class TInterval() extends Type +{ + override def toString() = name +} @@ -170,29 +134,29 @@ These are the files that need to change to extend the TUser - change sealed trait Type so that all the instances of TUser match the new parameters - mimir.sql.sqlite.SQLiteCompat: - update TUser type parameters - */ -object TypeRegistry { - val registeredTypes = Map[String,(Regex,Type)]( - "tuser" -> ("USER".r, TString()), - "tweight" -> ("KG*".r, TString()), - "productid" -> ("^P\\d+$".r, TString()), - "firecompany" -> ("^[a-zA-Z]\\d{3}$".r, TString()), - "zipcode" -> ("^\\d{5}(?:[-\\s]\\d{4})?$".r, TString()), - "container" -> ("^[A-Z]{4}[0-9]{7}$".r, TString()), - "carriercode" -> ("^[A-Z]{4}$".r, TString()), - "mmsi" -> ("^MID\\d{6}|0MID\\d{5}|00MID\\{4}$".r, TString()), - "billoflanding" -> ("^[A-Z]{8}[0-9]{8}$".r, TString()), - "imo_code" -> ("^\\d{7}$".r, TInt()) - ) - val idxType = registeredTypes.keys.toIndexedSeq - val typeIdx = idxType.zipWithIndex.toMap - - val matchers = registeredTypes.toSeq.map( t => (t._2._1, t._1) ) - - def baseType(t: String): Type = registeredTypes(t)._2 - - def matcher(t: String): Regex = registeredTypes(t)._1 - - def matches(t: String, v: String): Boolean = - !matcher(t).findFirstIn(v).isEmpty -} + */ +// object TypeRegistry { +// val registeredTypes = Map[String,(Regex,Type)]( +// "tuser" -> ("USER".r, TString()), +// "tweight" -> ("KG*".r, TString()), +// "productid" -> ("^P\\d+$".r, TString()), +// "firecompany" -> ("^[a-zA-Z]\\d{3}$".r, TString()), +// "zipcode" -> ("^\\d{5}(?:[-\\s]\\d{4})?$".r, TString()), +// "container" -> ("^[A-Z]{4}[0-9]{7}$".r, TString()), +// "carriercode" -> ("^[A-Z]{4}$".r, TString()), +// "mmsi" -> ("^MID\\d{6}|0MID\\d{5}|00MID\\{4}$".r, TString()), +// "billoflanding" -> ("^[A-Z]{8}[0-9]{8}$".r, TString()), +// "imo_code" -> ("^\\d{7}$".r, TInt()) +// ) +// val idxType = registeredTypes.keys.toIndexedSeq +// val typeIdx = idxType.zipWithIndex.toMap + +// val matchers = registeredTypes.toSeq.map( t => (t._2._1, t._1) ) + +// def baseType(t: String): Type = registeredTypes(t)._2 + +// def matcher(t: String): Regex = registeredTypes(t)._1 + +// def matches(t: String, v: String): Boolean = +// !matcher(t).findFirstIn(v).isEmpty +// } diff --git a/src/main/scala/mimir/algebra/Typechecker.scala b/src/main/scala/mimir/algebra/Typechecker.scala index fe81bdfe..0e1b1345 100644 --- a/src/main/scala/mimir/algebra/Typechecker.scala +++ b/src/main/scala/mimir/algebra/Typechecker.scala @@ -6,32 +6,33 @@ import com.typesafe.scalalogging.slf4j.LazyLogging import mimir.Database import mimir.algebra.function._ +import mimir.algebra.typeregistry._ import mimir.models.{Model, ModelManager} import Arith.{Add, Sub, Mult, Div, And, Or, BitAnd, BitOr, ShiftLeft, ShiftRight} import Cmp.{Gt, Lt, Lte, Gte, Eq, Neq, Like, NotLike} class TypecheckError(msg: String, e: Throwable, context: Option[Operator] = None) - extends Exception(msg, e) + extends Exception(msg, e) { - def errorTypeString = - getClass().getTypeName() + def errorTypeString = + getClass().getTypeName() - override def toString = - context match { - case None => s"$errorTypeString : $msg" - case Some(oper) => s"$errorTypeString : $msg\n$oper" - } + override def toString = + context match { + case None => s"$errorTypeString : $msg" + case Some(oper) => s"$errorTypeString : $msg\n$oper" + } - override def getMessage = 
- context match { - case None => msg - case Some(oper) => s"$msg in ${oper.toString.filter { _ != '\n' }.take(200)}" - } + override def getMessage = + context match { + case None => msg + case Some(oper) => s"$msg in ${oper.toString.filter { _ != '\n' }.take(200)}" + } } class MissingVariable(varName: String, e: Throwable, context: Option[Operator] = None) - extends TypecheckError(varName, e, context); + extends TypecheckError(varName, e, context); /** * ExpressionChecker wraps around a bit of context that makes @@ -51,116 +52,127 @@ class MissingVariable(varName: String, e: Throwable, context: Option[Operator] = * error with this operator. */ class Typechecker( - functions: Option[FunctionRegistry] = None, - aggregates: Option[AggregateRegistry] = None, - models: Option[ModelManager] = None + functions: Option[FunctionRegistry] = None, + aggregates: Option[AggregateRegistry] = None, + types: TypeRegistry = DefaultTypeRegistry, + models: Option[ModelManager] = None ) extends LazyLogging { - /* Assert that the expressions claimed type is its type */ - def assert(e: Expression, t: Type, scope: (String => Type), context: Option[Operator] = None, msg: String = "Typechecker"): Unit = { - val eType = typeOf(e, scope); - if(!Typechecker.canCoerce(eType, t)){ - logger.trace(s"LUB: ${Typechecker.leastUpperBound(eType, t)}") - throw new TypeException(eType, t, msg, Some(e)) - } - } - - def weakTypeOf(e: Expression) = - typeOf(e, (_) => TAny()) - - def typeOf(e: Expression, o: Operator): Type = - typeOf(e, scope = schemaOf(o).toMap, context = Some(o)) - - def typeOf( - e: Expression, - scope: (String => Type) = { (_:String) => throw new RAException("Need a scope to typecheck expressions with variables") }, - context: Option[Operator] = None - ): Type = { - val recur = typeOf(_:Expression, scope, context) - - e match { - case p: PrimitiveValue => p.getType; - case Not(child) => assert(child, TBool(), scope, context, "NOT"); TBool() - case p: Proc => p.getType(p.children.map(recur(_))) - case Arithmetic(op, lhs, rhs) => - Typechecker.escalate(recur(lhs), recur(rhs), op, "Arithmetic", e) - case Comparison(op, lhs, rhs) => { - op match { - case (Eq | Neq) => - Typechecker.assertLeastUpperBound(recur(lhs), recur(rhs), "Comparison", e) - case (Gt | Gte | Lt | Lte) => - Typechecker.assertOneOf( - Typechecker.assertLeastUpperBound( - recur(lhs), - recur(rhs), - "Comparison", - e - ), - Set(TDate(), TInterval(), TTimestamp(), TInt(), TFloat()), - e - ) - case (Like | NotLike) => - assert(lhs, TString(), scope, context, "LIKE") - assert(rhs, TString(), scope, context, "LIKE") - } - TBool() - } - case Var(name) => - try { - val t = scope(name) - logger.debug(s"Type of $name is $t") - t - } catch { - case x:NoSuchElementException => throw new MissingVariable(name, x, context) - } - case JDBCVar(t) => t - case Function("CAST", fargs) => - // Special case CAST - fargs(1) match { - case TypePrimitive(t) => t - case p:PrimitiveValue => { p match { - case StringPrimitive(s) => Type.toSQLiteType(Integer.parseInt(s)) - case IntPrimitive(i) => Type.toSQLiteType(i.toInt) - case _ => throw new RAException("Invalid CAST to '"+p+"' of type: "+recur(p)) + /* Assert that the expressions claimed type is its type */ + def assert(e: Expression, t: Type, scope: (String => Type), context: Option[Operator] = None, msg: String = "Typechecker"): Unit = { + val eType = typeOf(e, scope); + if(!Typechecker.canCoerce(eType, t, types)){ + logger.trace(s"LUB: ${Typechecker.leastUpperBound(eType, t, types)}") + throw new 
TypeException(eType, t, msg, Some(e)) + } + } + + def weakTypeOf(e: Expression) = + typeOf(e, (_) => TAny()) + + def typeOf(e: Expression, o: Operator): Type = + typeOf(e, scope = schemaOf(o).toMap, context = Some(o)) + + def rootType = types.rootType _ + + def typeOf( + e: Expression, + scope: (String => Type) = { (_:String) => throw new RAException("Need a scope to typecheck expressions with variables") }, + context: Option[Operator] = None + ): Type = { + val recur = typeOf(_:Expression, scope, context) + + e match { + case p: PrimitiveValue => p.getType; + case Not(child) => assert(child, TBool(), scope, context, "NOT"); TBool() + case p: Proc => p.getType(p.children.map { recur(_) } + .map { types.rootType(_) } + ) + case Arithmetic(op, lhs, rhs) => + Typechecker.escalate(types.rootType(recur(lhs)), types.rootType(recur(rhs)), op, "Arithmetic", e) + case Comparison(op, lhs, rhs) => { + op match { + case (Eq | Neq) => + Typechecker.assertLeastUpperBound(recur(lhs), recur(rhs), "Comparison", e, types) + case (Gt | Gte | Lt | Lte) => + Typechecker.assertOneOf( + types.rootType( + Typechecker.assertLeastUpperBound( + recur(lhs), + recur(rhs), + "Comparison", + e, + types + ) + ), + Set(TDate(), TInterval(), TTimestamp(), TInt(), TFloat()), + e + ) + case (Like | NotLike) => + assert(lhs, TString(), scope, context, "LIKE") + assert(rhs, TString(), scope, context, "LIKE") + } + TBool() + } + case Var(name) => + try { + val t = scope(name) + logger.debug(s"Type of $name is $t") + t + } catch { + case x:NoSuchElementException => throw new MissingVariable(name, x, context) + } + case JDBCVar(t) => t + case Function("CAST", fargs) => + // Special case CAST + fargs(1) match { + case TypePrimitive(t) => t + case p:PrimitiveValue => { p match { + case StringPrimitive(s) => types.typeForId(Integer.parseInt(s)) + case IntPrimitive(i) => types.typeForId(i.toInt) + case _ => throw new RAException("Invalid CAST to '"+p+"' of type: "+recur(p)) } - } - case _ => TAny() - } - case Function(name, args) => - returnTypeOfFunction(name, args.map { recur(_) }) - - case Conditional(condition, thenClause, elseClause) => - assert(condition, TBool(), scope, context, "WHEN") - Typechecker.assertLeastUpperBound( - recur(elseClause), - recur(thenClause), - "CASE-WHEN", - e - ) - case IsNullExpression(child) => - recur(child); - TBool() - case RowIdVar() => TRowId() - case VGTerm(model, idx, args, hints) => - models match { - case Some(registry) => - registry.get(model).varType(idx, args.map(recur(_))) - case None => throw new RAException("Need Model Manager to typecheck expressions with VGTerms") - } + } + case _ => TAny() + } + case Function(name, args) => + returnTypeOfFunction(name, args.map { recur(_) }.map { types.rootType(_) }) + + case Conditional(condition, thenClause, elseClause) => + assert(condition, TBool(), scope, context, "WHEN") + Typechecker.assertLeastUpperBound( + recur(elseClause), + recur(thenClause), + "CASE-WHEN", + e, + types + ) + case IsNullExpression(child) => + recur(child); + TBool() + case RowIdVar() => TRowId() + case VGTerm(model, idx, args, hints) => + models match { + case Some(registry) => + registry.get(model).varType(idx, + args.map { recur(_) }.map { types.rootType(_) } + ) + case None => throw new RAException("Need Model Manager to typecheck expressions with VGTerms") + } } } - def returnTypeOfFunction(name: String, args: Seq[Type]): Type = { + def returnTypeOfFunction(name: String, args: Seq[BaseType]): BaseType = { try { functions.flatMap { _.getOption(name) } match { case 
Some(NativeFunction(_, _, getType, _)) => getType(args) case Some(ExpressionFunction(_, argNames, expr)) => - typeOf(expr, scope = argNames.zip(args).toMap) + types.rootType(typeOf(expr, scope = argNames.zip(args).toMap)) case Some(FoldFunction(_, expr)) => args.tail.foldLeft(args.head){ case (curr,next) => - typeOf(expr, Map("CURR" -> curr, "NEXT" -> next)) } + types.rootType(typeOf(expr, Map("CURR" -> curr, "NEXT" -> next))) } case None => - throw new RAException(s"Function $name(${args.mkString(",")}) is undefined") + throw new RAException(s"Function $name(${args.mkString(",")}) is undefined") } } catch { case TypeException(found, expected, detail, None) => @@ -168,195 +180,233 @@ class Typechecker( } } - - def schemaOf(o: Operator): Seq[(String, Type)] = - { - o match { - case Project(cols, src) => - val schema = schemaOf(src).toMap - cols.map( { - case ProjectArg(col, expression) => - (col, typeOf(expression, scope = schema(_), context = Some(src))) - }) - - case ProvenanceOf(psel) => - // Not 100% sure this is kosher... doesn't ProvenanceOf introduce new columns? + def schemaOf(o: Operator): Seq[(String, Type)] = + { + o match { + case Project(cols, src) => + val schema = schemaOf(src).toMap + cols.map( { + case ProjectArg(col, expression) => + (col, typeOf(expression, scope = schema(_), context = Some(src))) + }) + + case ProvenanceOf(psel) => + // Not 100% sure this is kosher... doesn't ProvenanceOf introduce new columns? schemaOf(psel) - case Annotate(subj,invisScm) => { + case Annotate(subj,invisScm) => { schemaOf(subj) } - case Recover(subj,invisScm) => { + case Recover(subj,invisScm) => { schemaOf(subj).union(invisScm.map(_._2).map(pasct => (pasct.name,pasct.typ))) } - + case Select(cond, src) => { - val srcSchema = schemaOf(src); - assert(cond, TBool(), srcSchema.toMap, Some(src), "SELECT") - return srcSchema + val srcSchema = schemaOf(src); + assert(cond, TBool(), srcSchema.toMap, Some(src), "SELECT") + return srcSchema } - case Aggregate(gbCols, aggCols, src) => - aggregates match { - case None => throw new RAException("Need Aggregate Registry to typecheck aggregates") - case Some(registry) => { + case Aggregate(gbCols, aggCols, src) => + aggregates match { + case None => throw new RAException("Need Aggregate Registry to typecheck aggregates") + case Some(registry) => { - /* Nested Typechecker */ - val srcSchema = schemaOf(src).toMap - val chk = typeOf(_:Expression, scope = srcSchema, context = Some(src)) + /* Nested Typechecker */ + val srcSchema = schemaOf(src).toMap + val chk = typeOf(_:Expression, scope = srcSchema, context = Some(src)) - /* Get Group By Args and verify type */ - val groupBySchema: Seq[(String, Type)] = gbCols.map(x => (x.toString, chk(x))) + /* Get Group By Args and verify type */ + val groupBySchema: Seq[(String, Type)] = gbCols.map(x => (x.toString, chk(x))) - /* Get function name, check for AVG *//* Get function parameters, verify type */ - val aggSchema: Seq[(String, Type)] = aggCols.map(x => - ( - x.alias, - registry.typecheck(x.function, x.args.map(chk(_))) - ) - ) + /* Get function name, check for AVG *//* Get function parameters, verify type */ + val aggSchema: Seq[(String, Type)] = aggCols.map(x => + ( + x.alias, + registry.typecheck(x.function, x.args.map { chk(_) }.map { types.rootType(_) }) + ) + ) - /* Send schema to parent operator */ - val sch = groupBySchema ++ aggSchema - //println(sch) - sch + /* Send schema to parent operator */ + val sch = groupBySchema ++ aggSchema + //println(sch) + sch - } + } - } + } - case Join(lhs, rhs) => - 
val lSchema = schemaOf(lhs); - val rSchema = schemaOf(rhs); + case Join(lhs, rhs) => + val lSchema = schemaOf(lhs); + val rSchema = schemaOf(rhs); - val overlap = lSchema.map(_._1).toSet & rSchema.map(_._1).toSet - if(!(overlap.isEmpty)){ - throw new RAException("Ambiguous Keys ('"+overlap+"') in Cross Product\n"+o); - } - lSchema ++ rSchema + val overlap = lSchema.map(_._1).toSet & rSchema.map(_._1).toSet + if(!(overlap.isEmpty)){ + throw new RAException("Ambiguous Keys ('"+overlap+"') in Cross Product\n"+o); + } + lSchema ++ rSchema - case LeftOuterJoin(lhs, rhs, condition) => - schemaOf(Select(condition, Join(lhs, rhs))) + case LeftOuterJoin(lhs, rhs, condition) => + schemaOf(Select(condition, Join(lhs, rhs))) - case Union(lhs, rhs) => - val lSchema = schemaOf(lhs); - val rSchema = schemaOf(rhs); + case Union(lhs, rhs) => + val lSchema = schemaOf(lhs); + val rSchema = schemaOf(rhs); - if(!(lSchema.map(_._1).toSet.equals(rSchema.map(_._1).toSet))){ - throw new RAException("Schema Mismatch in Union\n"+o+s"$lSchema <> $rSchema"); - } - lSchema + if(!(lSchema.map(_._1).toSet.equals(rSchema.map(_._1).toSet))){ + throw new RAException("Schema Mismatch in Union\n"+o+s"$lSchema <> $rSchema"); + } + lSchema - case Table(_, _, sch, meta) => (sch ++ meta.map( x => (x._1, x._3) )) + case Table(_, _, sch, meta) => (sch ++ meta.map( x => (x._1, x._3) )) - case View(_, query, _) => schemaOf(query) - case AdaptiveView(_, _, query, _) => schemaOf(query) + case View(_, query, _) => schemaOf(query) + case AdaptiveView(_, _, query, _) => schemaOf(query) - case HardTable(sch,_) => sch + case HardTable(sch,_) => sch - case Limit(_, _, src) => schemaOf(src) + case Limit(_, _, src) => schemaOf(src) - case Sort(_, src) => schemaOf(src) - } - } + case Sort(_, src) => schemaOf(src) + } + } + + def baseSchemaOf(o: Operator): Seq[(String, BaseType)] = + schemaOf(o).map { case (name, t) => (name, types.rootType(t)) } } object Typechecker extends LazyLogging { - val trivialTypechecker = new Typechecker() - - def assertNumeric(t: Type, e: Expression): Type = - { - if(!Type.isNumeric(t)){ - throw new TypeException(t, TFloat(), "Numeric", Some(e)) - } - t; - } - - def canCoerce(from: Type, to: Type): Boolean = - { - logger.debug("Coerce from $from to $to") - leastUpperBound(from, to) match { - case Some(lub) => lub.equals(to) - case None => false - } - } - - def leastUpperBound(a: Type, b: Type): Option[Type] = - { - if(a.equals(b)){ return Some(a); } - (a, b) match { - case (TAny(), _) => Some(b) - case (_, TAny()) => Some(a) - case (TInt(), TFloat()) => Some(TFloat()) - case (TFloat(), TInt()) => Some(TFloat()) - case (TDate(), TTimestamp()) => Some(TTimestamp()) - case (TTimestamp(), TDate()) => Some(TTimestamp()) - case (TRowId(), TString()) => Some(TRowId()) - case (TString(), TRowId()) => Some(TRowId()) - case (TRowId(), TInt()) => Some(TInt()) - case (TInt(), TRowId()) => Some(TInt()) - case (TUser(name), _) => leastUpperBound(TypeRegistry.baseType(name), b) - case (_, TUser(name)) => leastUpperBound(a, TypeRegistry.baseType(name)) - case _ => return None - } - } - - def leastUpperBound(tl: TraversableOnce[Type]): Option[Type] = - { - tl.map { Some(_) }.fold(Some(TAny()):Option[Type]) { case (Some(a), Some(b)) => leastUpperBound(a, b) case _ => None } - } - - def assertLeastUpperBound(a: Type, b: Type, msg: String, e: Expression): Type = - { - leastUpperBound(a, b) match { - case Some(t) => t - case None => throw new TypeException(a, b, msg, Some(e)) - } - } - def assertLeastUpperBound(tl: TraversableOnce[Type], 
msg: String, e: Expression): Type = - { - tl.fold(TAny()) { assertLeastUpperBound(_, _, msg, e) } - } - - def assertOneOf(a: Type, candidates: TraversableOnce[Type], e: Expression): Type = - { - candidates.flatMap { leastUpperBound(a, _) }.collectFirst { case x => x } match { - case Some(t) => t - case None => - throw new TypeException(a, TAny(), s"Not one of $candidates", Some(e)) - } - } - - def escalate(a: Type, b: Type, op: Arith.Op, msg: String, e: Expression): Type = - { - escalate(a, b, op) match { - case Some(t) => t - case None => throw new TypeException(a, b, msg, Some(e)); - } - } - def escalate(a: Type, b: Type, op: Arith.Op): Option[Type] = - { - // Start with special case overrides - (a, b, op) match { - case (TDate() | TTimestamp(), - TDate() | TTimestamp(), - Arith.Add) => return Some(TDate()) + val trivialTypechecker = new Typechecker() + + def assertNumeric(t: BaseType, e: Expression): BaseType = + { + if(!t.isNumeric){ + throw new TypeException(t, TFloat(), "Numeric", Some(e)) + } + t; + } + + def canCoerce(from: BaseType, to: BaseType): Boolean = + { + logger.debug(s"Coerce from $from to $to") + leastUpperBound(from, to) match { + case Some(lub) => lub.equals(to) + case None => false + } + } + + def canCoerce(from: Type, to: Type, types: TypeRegistry): Boolean = + { + logger.debug(s"Coerce from $from to $to") + leastUpperBound(from, to, types) match { + case Some(lub) => lub.equals(to) + case None => false + } + } + + def leastUpperBound(a: BaseType, b: BaseType): Option[BaseType] = + { + if(a.equals(b)){ return Some(a); } + (a, b) match { + case (TAny(), _) => Some(b) + case (_, TAny()) => Some(a) + case (TInt(), TFloat()) => Some(TFloat()) + case (TFloat(), TInt()) => Some(TFloat()) + case (TDate(), TTimestamp()) => Some(TTimestamp()) + case (TTimestamp(), TDate()) => Some(TTimestamp()) + case (TRowId(), TString()) => Some(TRowId()) + case (TString(), TRowId()) => Some(TRowId()) + case (TRowId(), TInt()) => Some(TInt()) + case (TInt(), TRowId()) => Some(TInt()) + case _ => return None + } + } + + def leastUpperBound(tl: TraversableOnce[BaseType]): Option[BaseType] = + { + tl.map { Some(_) }.fold(Some(TAny()):Option[BaseType]) { case (Some(a), Some(b)) => leastUpperBound(a, b) case _ => None } + } + + def assertLeastUpperBound(a: BaseType, b: BaseType, msg: String, e: Expression): BaseType = + { + leastUpperBound(a, b) match { + case Some(t) => t + case None => throw new TypeException(a, b, msg, Some(e)) + } + } + def assertLeastUpperBound(tl: TraversableOnce[BaseType], msg: String, e: Expression): BaseType = + { + tl.fold(TAny()) { assertLeastUpperBound(_, _, msg, e) } + } + + def assertOneOf(a: BaseType, candidates: TraversableOnce[BaseType], e: Expression): BaseType = + { + candidates.flatMap { leastUpperBound(a, _) }.collectFirst { case x => x } match { + case Some(t) => t + case None => + throw new TypeException(a, TAny(), s"Not one of $candidates", Some(e)) + } + } + + def leastUpperBound(a: Type, b: Type, types:TypeRegistry): Option[Type] = + { + (a, b) match { + case (TUser(_), TUser(_)) => leastUpperBound(types.parentType(a), b, types) + .map { Some(_) } + .getOrElse{ leastUpperBound(a, types.parentType(b), types) } + case (TUser(_), _:BaseType) => leastUpperBound(types.parentType(a), b, types) + case (_:BaseType, TUser(_)) => leastUpperBound(a, types.parentType(b), types) + case (aBase:BaseType, bBase:BaseType) => leastUpperBound(aBase, bBase) + } + } + + def leastUpperBound(tl: TraversableOnce[Type], types:TypeRegistry): Option[Type] = + { + tl.map { Some(_)
}.fold(Some(TAny()):Option[Type]) { case (Some(a), Some(b)) => leastUpperBound(a, b, types) case _ => None } + } + + def assertLeastUpperBound(a: Type, b: Type, msg: String, e: Expression, types:TypeRegistry): Type = + { + leastUpperBound(a, b, types) match { + case Some(t) => t + case None => throw new TypeException(a, b, msg, Some(e)) + } + } + def assertLeastUpperBound(tl: TraversableOnce[Type], msg: String, e: Expression, types:TypeRegistry): Type = + { + tl.fold(TAny()) { assertLeastUpperBound(_, _, msg, e, types) } + } + + def escalate(a: BaseType, b: BaseType, op: Arith.Op, msg: String, e: Expression): BaseType = + { + escalate(a, b, op) match { + case Some(t) => t + case None => throw new TypeException(a, b, msg, Some(e)); + } + } + def escalate(a: BaseType, b: BaseType, op: Arith.Op): Option[BaseType] = + { + // Start with special case overrides + (a, b, op) match { + case (TDate() | TTimestamp(), + TDate() | TTimestamp(), + Arith.Add) => return Some(TDate()) // Interval Arithmetic - case (TDate() | TTimestamp(), - TDate() | TTimestamp(), - Arith.Sub) => return Some(TInterval()) - case (TDate() | TTimestamp() | TInterval(), - TInterval(), - Arith.Sub | Arith.Add) => return Some(a) - case (TInt() | TFloat(), TInterval(), Arith.Mult) | - (TInterval(), TInt() | TFloat(), Arith.Mult | Arith.Div) - => return Some(TInterval()) - case (TInterval(), TInterval(), Arith.Div) - => return Some(TFloat()) + case (TDate() | TTimestamp(), + TDate() | TTimestamp(), + Arith.Sub) => return Some(TInterval()) + case (TDate() | TTimestamp() | TInterval(), + TInterval(), + Arith.Sub | Arith.Add) => return Some(a) + case (TInt() | TFloat(), TInterval(), Arith.Mult) | + (TInterval(), TInt() | TFloat(), Arith.Mult | Arith.Div) + => return Some(TInterval()) + case (TInterval(), TInterval(), Arith.Div) + => return Some(TFloat()) // TAny() cases case (TAny(), TAny(), _) => return Some(TAny()) @@ -364,45 +414,45 @@ object Typechecker Arith.Sub) => Some(TInterval()) case (TDate() | TTimestamp(), TAny(), Arith.Sub) => Some(TAny()) // Either TInterval or TDate, depending - case _ => () - } - - (op) match { - case (Arith.Add | Arith.Sub | Arith.Mult | Arith.Div) => - if(Type.isNumeric(a, treatTAnyAsNumeric = true) && Type.isNumeric(b, treatTAnyAsNumeric = true)){ - leastUpperBound(a, b) - } else { + case _ => () + } + + (op) match { + case (Arith.Add | Arith.Sub | Arith.Mult | Arith.Div) => + if( (a.isNumeric || a == TAny()) && (b.isNumeric || b == TAny()) ){ + leastUpperBound(a, b) + } else { None - } + } case (Arith.BitAnd | Arith.BitOr | Arith.ShiftLeft | Arith.ShiftRight) => - (Type.rootType(a), Type.rootType(b)) match { + (a, b) match { case (TInt() | TAny(), TInt() | TAny()) => Some(TInt()) case _ => None } case (Arith.And | Arith.Or) => - (Type.rootType(a), Type.rootType(b)) match { + (a, b) match { case (TBool() | TAny(), TBool() | TAny()) => Some(TBool()) case _ => None } - } - } - def escalate(a: Option[Type], b: Option[Type], op: Arith.Op): Option[Type] = - { - (a, b) match { - case (None,_) => b - case (_,None) => a - case (Some(at), Some(bt)) => escalate(at, bt, op) - } - } - - def escalate(l: TraversableOnce[Type], op: Arith.Op): Option[Type] = - { - l.map(Some(_)).fold(None)(escalate(_,_,op)) - } - def escalate(l: TraversableOnce[Type], op: Arith.Op, msg: String, e: Expression): Type = - { - l.fold(TAny())(escalate(_,_,op,msg,e)) - } + } + } + def escalate(a: Option[BaseType], b: Option[BaseType], op: Arith.Op): Option[BaseType] = + { + (a, b) match { + case (None,_) => b + case (_,None) => a + case 
(Some(at), Some(bt)) => escalate(at, bt, op) + } + } + + def escalate(l: TraversableOnce[BaseType], op: Arith.Op): Option[BaseType] = + { + l.map(Some(_)).fold(None)(escalate(_,_,op)) + } + def escalate(l: TraversableOnce[BaseType], op: Arith.Op, msg: String, e: Expression): BaseType = + { + l.fold(TAny())(escalate(_,_,op,msg,e)) + } } diff --git a/src/main/scala/mimir/algebra/function/AggregateRegistry.scala b/src/main/scala/mimir/algebra/function/AggregateRegistry.scala index 5445fc32..6ca0c1a1 100644 --- a/src/main/scala/mimir/algebra/function/AggregateRegistry.scala +++ b/src/main/scala/mimir/algebra/function/AggregateRegistry.scala @@ -5,10 +5,10 @@ import mimir.algebra._ case class RegisteredAggregate( aggName: String, - typechecker: (Seq[Type] => Type), + typechecker: (Seq[BaseType] => BaseType), defaultValue: PrimitiveValue ){ - def typecheck(args: Seq[Type]) = typechecker(args) + def typecheck(args: Seq[BaseType]) = typechecker(args) } class AggregateRegistry @@ -28,14 +28,14 @@ class AggregateRegistry register("GROUP_BITWISE_AND", List(TInt()), TInt(), IntPrimitive(Long.MaxValue)) register("GROUP_BITWISE_OR", List(TInt()), TInt(), IntPrimitive(0)) register("JSON_GROUP_ARRAY", (t) => TString(), StringPrimitive("[]")) - register("FIRST", (t:Seq[Type]) => t.head, NullPrimitive()) - register("FIRST_FLOAT", (t:Seq[Type]) => t.head, NullPrimitive()) - register("FIRST_INT", (t:Seq[Type]) => t.head, NullPrimitive()) + register("FIRST", (t:Seq[BaseType]) => t.head, NullPrimitive()) + register("FIRST_FLOAT", (t:Seq[BaseType]) => t.head, NullPrimitive()) + register("FIRST_INT", (t:Seq[BaseType]) => t.head, NullPrimitive()) } def register( aggName: String, - typechecker: Seq[Type] => Type, + typechecker: Seq[BaseType] => BaseType, defaultValue: PrimitiveValue ): Unit = { prototypes.put(aggName, RegisteredAggregate(aggName, typechecker, defaultValue)) @@ -58,8 +58,8 @@ class AggregateRegistry def register( aggName: String, - argTypes: Seq[Type], - retType: Type, + argTypes: Seq[BaseType], + retType: BaseType, defaultValue: PrimitiveValue ): Unit = { register( @@ -79,7 +79,7 @@ class AggregateRegistry ) } - def typecheck(aggName: String, args: Seq[Type]): Type = + def typecheck(aggName: String, args: Seq[BaseType]): BaseType = prototypes(aggName).typecheck(args) def isAggregate(aggName: String): Boolean = diff --git a/src/main/scala/mimir/algebra/function/FunctionRegistry.scala b/src/main/scala/mimir/algebra/function/FunctionRegistry.scala index 873dbcb2..d9c06e1a 100644 --- a/src/main/scala/mimir/algebra/function/FunctionRegistry.scala +++ b/src/main/scala/mimir/algebra/function/FunctionRegistry.scala @@ -12,7 +12,7 @@ sealed abstract class RegisteredFunction { val name: String } case class NativeFunction( name: String, evaluator: Seq[PrimitiveValue] => PrimitiveValue, - typechecker: Seq[Type] => Type, + typechecker: Seq[BaseType] => BaseType, passthrough:Boolean = false ) extends RegisteredFunction @@ -49,14 +49,14 @@ class FunctionRegistry { def register( fname:String, eval:Seq[PrimitiveValue] => PrimitiveValue, - typechecker: Seq[Type] => Type + typechecker: Seq[BaseType] => BaseType ): Unit = register(new NativeFunction(fname, eval, typechecker)) def registerPassthrough( fname:String, eval:Seq[PrimitiveValue] => PrimitiveValue, - typechecker: Seq[Type] => Type + typechecker: Seq[BaseType] => BaseType ): Unit = register(new NativeFunction(fname, eval, typechecker, true)) diff --git a/src/main/scala/mimir/algebra/function/GeoFunctions.scala 
b/src/main/scala/mimir/algebra/function/GeoFunctions.scala index 6009806d..1b1a8b75 100644 --- a/src/main/scala/mimir/algebra/function/GeoFunctions.scala +++ b/src/main/scala/mimir/algebra/function/GeoFunctions.scala @@ -26,7 +26,7 @@ object GeoFunctions args(3).asDouble //lat2 )) }, - (args) => { + (args:Seq[BaseType]) => { (0 until 4).foreach { i => Typechecker.assertNumeric(args(i), Function("DST", List())) }; TFloat() } diff --git a/src/main/scala/mimir/algebra/function/JsonFunctions.scala b/src/main/scala/mimir/algebra/function/JsonFunctions.scala index 30f26c3d..23d8efc9 100644 --- a/src/main/scala/mimir/algebra/function/JsonFunctions.scala +++ b/src/main/scala/mimir/algebra/function/JsonFunctions.scala @@ -14,7 +14,7 @@ object JsonFunctions } } - def extract(args: Seq[PrimitiveValue], t: Type): PrimitiveValue = + def extract(args: Seq[PrimitiveValue], t: BaseType): PrimitiveValue = { Json.toPrimitive(t, extract(args)) } diff --git a/src/main/scala/mimir/algebra/function/NumericFunctions.scala b/src/main/scala/mimir/algebra/function/NumericFunctions.scala index e27501a3..c28e04fd 100644 --- a/src/main/scala/mimir/algebra/function/NumericFunctions.scala +++ b/src/main/scala/mimir/algebra/function/NumericFunctions.scala @@ -14,14 +14,14 @@ object NumericFunctions case Seq(NullPrimitive()) => NullPrimitive() case x => throw new RAException("Non-numeric parameter to absolute: '"+x+"'") }, - (x: Seq[Type]) => Typechecker.assertNumeric(x(0), Function("ABSOLUTE", List())) + (x: Seq[BaseType]) => Typechecker.assertNumeric(x(0), Function("ABSOLUTE", List())) ) fr.register("SQRT", { case Seq(n:NumericPrimitive) => FloatPrimitive(Math.sqrt(n.asDouble)) }, - (x: Seq[Type]) => Typechecker.assertNumeric(x(0), Function("SQRT", List())) + (x: Seq[BaseType]) => Typechecker.assertNumeric(x(0), Function("SQRT", List())) ) fr.register("BITWISE_AND", @@ -38,17 +38,17 @@ object NumericFunctions fr.register("STDDEV",(_) => ???, (_) => TFloat()) fr.register("min", { - case ints:Seq[IntPrimitive] => IntPrimitive(ints.foldLeft(ints.head.v)( (init, intval) => Math.min(init, intval.v))) + (ints:Seq[PrimitiveValue]) => IntPrimitive(ints.foldLeft(ints.head.asInt)( (init, intval) => Math.min(init, intval.asInt))) }, (_) => TInt()) fr.register("max",{ - case ints:Seq[IntPrimitive] => IntPrimitive(ints.foldLeft(ints.head.v)( (init, intval) => Math.max(init, intval.v))) + (ints:Seq[PrimitiveValue]) => IntPrimitive(ints.foldLeft(ints.head.asInt)( (init, intval) => Math.max(init, intval.asInt))) }, (_) => TInt()) fr.register("MIN", { - case ints:Seq[IntPrimitive] => IntPrimitive(ints.foldLeft(ints.head.v)( (init, intval) => Math.min(init, intval.v))) + (ints:Seq[PrimitiveValue]) => IntPrimitive(ints.foldLeft(ints.head.asInt)( (init, intval) => Math.min(init, intval.asInt))) }, (_) => TInt()) fr.register("MAX",{ - case ints:Seq[IntPrimitive] => IntPrimitive(ints.foldLeft(ints.head.v)( (init, intval) => Math.max(init, intval.v))) + (ints:Seq[PrimitiveValue]) => IntPrimitive(ints.foldLeft(ints.head.asInt)( (init, intval) => Math.max(init, intval.asInt))) }, (_) => TInt()) fr.register("ROUND", @@ -60,7 +60,7 @@ object NumericFunctions FloatPrimitive(Math.round(number).toDouble) } }, - (x: Seq[Type]) => TFloat() + (x: Seq[BaseType]) => TFloat() ) fr.register("ABS", diff --git a/src/main/scala/mimir/algebra/function/SampleFunctions.scala b/src/main/scala/mimir/algebra/function/SampleFunctions.scala index 684f1851..705940ad 100644 --- a/src/main/scala/mimir/algebra/function/SampleFunctions.scala +++ 
b/src/main/scala/mimir/algebra/function/SampleFunctions.scala @@ -21,7 +21,7 @@ object SampleFunctions case None => NullPrimitive() } }, - (types: Seq[Type]) => { + (types: Seq[BaseType]) => { val debugExpr = Function("BEST_SAMPLE", types.map(TypePrimitive(_))) Typechecker.assertNumeric(types.head, debugExpr) @@ -42,7 +42,7 @@ object SampleFunctions FloatPrimitive( WorldBits.confidence(args(0).asLong, args(0).asLong.toInt) ), - (types: Seq[Type]) => { + (types: Seq[BaseType]) => { Typechecker.assertNumeric(types(0), Function("SAMPLE_CONFIDENCE", types.map(TypePrimitive(_)))) Typechecker.assertNumeric(types(1), diff --git a/src/main/scala/mimir/algebra/function/TypeFunctions.scala b/src/main/scala/mimir/algebra/function/TypeFunctions.scala index b8b7c3a2..f0aa10a0 100644 --- a/src/main/scala/mimir/algebra/function/TypeFunctions.scala +++ b/src/main/scala/mimir/algebra/function/TypeFunctions.scala @@ -24,7 +24,8 @@ object TypeFunctions functionName, (params: Seq[PrimitiveValue]) => { params match { - case Seq(x, TypePrimitive(t)) => Cast(t, x) + case Seq(x, TypePrimitive(t:BaseType)) => Cast(t, x) + case Seq(x, TypePrimitive(t:TUser)) => throw new RAException("Casting to User-defined types unsupported") case _ => throw new RAException("Invalid cast: "+params) } }, diff --git a/src/main/scala/mimir/algebra/gprom/OperatorTranslation.scala b/src/main/scala/mimir/algebra/gprom/OperatorTranslation.scala index c0d64378..723b9fc0 100644 --- a/src/main/scala/mimir/algebra/gprom/OperatorTranslation.scala +++ b/src/main/scala/mimir/algebra/gprom/OperatorTranslation.scala @@ -30,13 +30,8 @@ object ProjectionArgVisibility extends Enumeration { val Invisible = Value("Invisible") } -object OperatorTranslation extends LazyLogging { - - var db: mimir.Database = null - def apply(db: mimir.Database) = { - this.db = db - } - +class OperatorTranslation(db: mimir.Database) extends LazyLogging { + def gpromStructureToMimirOperator(depth : Int, gpromStruct: GProMStructure, gpromParentStruct: GProMStructure ) : Operator = { (gpromStruct match { case list:GProMList => { @@ -303,7 +298,7 @@ object OperatorTranslation extends LazyLogging { case "CAST" => { val castArgs = gpromListToScalaList(functionCall.args).map( gpromParam => translateGProMExpressionToMimirExpression(ctxOpers, gpromParam)) val fixedType = castArgs.last match { - case IntPrimitive(i) => TypePrimitive(Type.toSQLiteType(i.toInt)) + case IntPrimitive(i) => TypePrimitive(db.types.typeForId(i.toInt)) case TypePrimitive(t) => TypePrimitive(t) case x => x } @@ -882,7 +877,7 @@ object OperatorTranslation extends LazyLogging { (GProM_JNA.GProMDataType.GProM_DT_STRING,strPtr,0) } case TypePrimitive(t) => { - val v = Type.id(t) + val v = db.types.idForType(t) val intPtr = new Memory(Native.getNativeSize(classOf[Int])) intPtr.setInt(0, v.asInstanceOf[Int]); (GProM_JNA.GProMDataType.GProM_DT_INT,intPtr,0) diff --git a/src/main/scala/mimir/algebra/spark/OperatorTranslation.scala b/src/main/scala/mimir/algebra/spark/OperatorTranslation.scala index 0f12a195..900698df 100644 --- a/src/main/scala/mimir/algebra/spark/OperatorTranslation.scala +++ b/src/main/scala/mimir/algebra/spark/OperatorTranslation.scala @@ -14,7 +14,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, Ca import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, UnresolvedInlineTable, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.{JoinType, Inner,LeftOuter, Cross} import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,AggregateFunction,AggregateMode,Complete,Count,Average,Sum,First,Max,Min} - +import org.joda.time.Period import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.StringTrimLeft import org.apache.spark.sql.catalyst.expressions.StartsWith import org.apache.spark.sql.catalyst.expressions.Contains import mimir.algebra.function.FunctionRegistry -import mimir.algebra.function.NativeFunction +import mimir.algebra.function.{NativeFunction,FoldFunction,ExpressionFunction} import mimir.sql.SparkBackend import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -68,13 +68,9 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.CreateArray import org.apache.spark.sql.types.TypeCollection -object OperatorTranslation +class OperatorTranslation(db: mimir.Database) extends LazyLogging { - var db: mimir.Database = null - def apply(db: mimir.Database) = { - this.db = db - } def mimirOpToSparkOp(oper:Operator) : LogicalPlan = { oper match { @@ -197,12 +193,13 @@ object OperatorTranslation //LogicalRelation(baseRelation, table) } case View(name, query, annotations) => { - val schema = db.typechecker.schemaOf(query) + val schema = db.typechecker.baseSchemaOf(query) val table = CatalogTable( TableIdentifier(name), CatalogTableType.VIEW, CatalogStorageFormat.empty, - mimirSchemaToStructType(schema) ) + OperatorTranslation.mimirSchemaToStructType(schema) + ) org.apache.spark.sql.catalyst.plans.logical.View( table, schema.map(col => { AttributeReference(col._1, getSparkType(col._2), true, Metadata.empty)( ) @@ -210,12 +207,13 @@ object OperatorTranslation mimirOpToSparkOp(query)) } case av@AdaptiveView(schemaName, name, query, annotations) => { - val schema = db.typechecker.schemaOf(av) + val schema = db.typechecker.baseSchemaOf(av) val table = CatalogTable( TableIdentifier(name), CatalogTableType.VIEW, CatalogStorageFormat.empty, - mimirSchemaToStructType(schema)) + OperatorTranslation.mimirSchemaToStructType(schema) + ) org.apache.spark.sql.catalyst.plans.logical.View( table, schema.map(col => { AttributeReference(col._1, getSparkType(col._2), true, Metadata.empty)( ) @@ -226,8 +224,13 @@ object OperatorTranslation //UnresolvedInlineTable( schema.unzip._1, data.map(row => row.map(mimirExprToSparkExpr(oper,_)))) /*LocalRelation(mimirSchemaToStructType(schema).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()), data.map(row => InternalRow(row.map(mimirPrimitiveToSparkInternalRowValue(_)):_*)))*/ - LocalRelation.fromExternalRows(mimirSchemaToStructType(schema).map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()), - data.map(row => Row(row.map(mimirPrimitiveToSparkExternalRowValue(_)):_*))) + LocalRelation.fromExternalRows( + OperatorTranslation.mimirSchemaToStructType(schema.map { case (name, t) => (name, db.types.rootType(t)) }) + .map{ f => + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + }, + data.map(row => Row(row.map(OperatorTranslation.mimirPrimitiveToSparkExternalRowValue(_)):_*)) + ) } case Sort(sortCols, src) => { org.apache.spark.sql.catalyst.plans.logical.Sort( @@ -384,7 +387,7 @@ object OperatorTranslation }*/ def getPrivateMamber[T](inst:AnyRef, fieldName:String) : T = { - def _parents: Stream[Class[_]] = Stream(inst.getClass) #::: _parents.map(_.getSuperclass) + def _parents: Stream[Class[_]] = 
Stream[Class[_]](inst.getClass) #::: _parents.map(_.getSuperclass) val parents = _parents.takeWhile(_ != null).toList val fields = parents.flatMap(_.getDeclaredFields()) val field = fields.find(_.getName == fieldName).getOrElse(throw new IllegalArgumentException("Field " + fieldName + " not found")) @@ -396,7 +399,7 @@ object OperatorTranslation class PrivateMethodCaller(x: AnyRef, methodName: String) { def apply(_args: Any*): Any = { val args = _args.map(_.asInstanceOf[AnyRef]) - def _parents: Stream[Class[_]] = Stream(x.getClass) #::: _parents.map(_.getSuperclass) + def _parents: Stream[Class[_]] = Stream[Class[_]](x.getClass) #::: _parents.map(_.getSuperclass) val parents = _parents.takeWhile(_ != null).toList val methods = parents.flatMap(_.getDeclaredMethods) val method = methods.find(_.getName == methodName).getOrElse(throw new IllegalArgumentException("Method " + methodName + " not found")) @@ -512,7 +515,7 @@ object OperatorTranslation def mimirExprToSparkExpr(oper:Operator, expr:Expression) : org.apache.spark.sql.catalyst.expressions.Expression = { expr match { case primitive : PrimitiveValue => { - mimirPrimitiveToSparkPrimitive(primitive) + OperatorTranslation.mimirPrimitiveToSparkPrimitive(primitive) } case cmp@Comparison(op,lhs,rhs) => { mimirComparisonToSparkComparison(oper, cmp) @@ -542,22 +545,31 @@ object OperatorTranslation case BestGuess(model, idx, args, hints) => { val name = model.name //println(s"-------------------Translate BestGuess VGTerm($name, $idx, (${args.mkString(",")}), (${hints.mkString(",")}))") - BestGuessUDF(oper, model, idx, args, hints).getUDF + BestGuessUDF(oper, model, idx, + args.map { mimirExprToSparkExpr(oper, _) }, + hints.map { mimirExprToSparkExpr(oper, _) } + ).getUDF //UnresolvedFunction(mimir.ctables.CTables.FN_BEST_GUESS, mimirExprToSparkExpr(oper,StringPrimitive(name)) +: mimirExprToSparkExpr(oper,IntPrimitive(idx)) +: (args.map(mimirExprToSparkExpr(oper,_)) ++ hints.map(mimirExprToSparkExpr(oper,_))), true ) } case IsAcknowledged(model, idx, args) => { val name = model.name //println(s"-------------------Translate IsAcknoledged VGTerm($name, $idx, (${args.mkString(",")}))") - AckedUDF(oper, model, idx, args).getUDF + AckedUDF(oper, model, idx, args.map { mimirExprToSparkExpr(oper, _) }).getUDF //UnresolvedFunction(mimir.ctables.CTables.FN_IS_ACKED, mimirExprToSparkExpr(oper,StringPrimitive(name)) +: mimirExprToSparkExpr(oper,IntPrimitive(idx)) +: (args.map(mimirExprToSparkExpr(oper,_)) ), true ) } case Sampler(model, idx, args, hints, seed) => { - SampleUDF(oper, model, idx, seed, args, hints).getUDF + SampleUDF(oper, model, idx, seed, + args.map { mimirExprToSparkExpr(oper, _) }, + hints.map { mimirExprToSparkExpr(oper, _) } + ).getUDF } case VGTerm(name, idx, args, hints) => { //default to best guess //println(s"-------------------Translate VGTerm($name, $idx, (${args.mkString(",")}), (${hints.mkString(",")}))") val model = db.models.get(name) - BestGuessUDF(oper, model, idx, args, hints).getUDF + BestGuessUDF(oper, model, idx, + args.map { mimirExprToSparkExpr(oper, _) }, + hints.map { mimirExprToSparkExpr(oper, _) } + ).getUDF //UnresolvedFunction(mimir.ctables.CTables.FN_BEST_GUESS, mimirExprToSparkExpr(oper,StringPrimitive(name)) +: mimirExprToSparkExpr(oper,IntPrimitive(idx)) +: (args.map(mimirExprToSparkExpr(oper,_)) ++ hints.map(mimirExprToSparkExpr(oper,_))), true ) } case IsNullExpression(iexpr) => { @@ -571,64 +583,7 @@ object OperatorTranslation } } } - - val defaultDate = DateTimeUtils.toJavaDate(0) - val defaultTimestamp = 
DateTimeUtils.toJavaTimestamp(0L) - def mimirPrimitiveToSparkPrimitive(primitive : PrimitiveValue) : Literal = { - primitive match { - case NullPrimitive() => Literal(null) - case RowIdPrimitive(s) => Literal(s) - case StringPrimitive(s) => Literal(s) - case IntPrimitive(i) => Literal(i) - case FloatPrimitive(f) => Literal(f) - case BoolPrimitive(b) => Literal(b) - case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => Literal.create(SparkUtils.convertTimestamp(ts), TimestampType) - case dt@DatePrimitive(y,m,d) => Literal.create(SparkUtils.convertDate(dt), DateType) - case x => Literal(x.asString) - } - } - - def mimirPrimitiveToSparkInternalRowValue(primitive : PrimitiveValue) : Any = { - primitive match { - case NullPrimitive() => null - case RowIdPrimitive(s) => UTF8String.fromString(s) - case StringPrimitive(s) => UTF8String.fromString(s) - case IntPrimitive(i) => i - case FloatPrimitive(f) => f - case BoolPrimitive(b) => b - case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => SparkUtils.convertTimestamp(ts)//DateTimeUtils.fromJavaTimestamp(SparkUtils.convertTimestamp(ts)) - case dt@DatePrimitive(y,m,d) => SparkUtils.convertDate(dt)//DateTimeUtils.fromJavaDate(SparkUtils.convertDate(dt)) - case x => UTF8String.fromString(x.asString) - } - } - - def mimirPrimitiveToSparkExternalRowValue(primitive : PrimitiveValue) : Any = { - primitive match { - case NullPrimitive() => null - case RowIdPrimitive(s) => s - case StringPrimitive(s) => s - case IntPrimitive(i) => i - case FloatPrimitive(f) => f - case BoolPrimitive(b) => b - case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => SparkUtils.convertTimestamp(ts) - case dt@DatePrimitive(y,m,d) => SparkUtils.convertDate(dt) - case x => x.asString - } - } - - def mimirPrimitiveToSparkInternalInlineFuncParam(primitive : PrimitiveValue) : Any = { - primitive match { - case IntPrimitive(i) => i.toInt - case x => mimirPrimitiveToSparkInternalRowValue(x) - } - } - - def mimirPrimitiveToSparkExternalInlineFuncParam(primitive : PrimitiveValue) : Any = { - primitive match { - case IntPrimitive(i) => i.toInt - case x => mimirPrimitiveToSparkExternalRowValue(x) - } - } + def mimirComparisonToSparkComparison(oper:Operator, cmp:Comparison) : org.apache.spark.sql.catalyst.expressions.Expression = { cmp.op match { @@ -690,11 +645,25 @@ object OperatorTranslation throw new Exception(s"Function Translation not implemented $vgtBGFunc(${params.mkString(",")})") } case Function(name, params) => { - FunctionUDF(oper, name, db.functions.get(name), params, params.map(arg => db.typechecker.typeOf(arg, oper))).getUDF + FunctionUDF(oper, name, db.functions.get(name), + params.map { mimirExprToSparkExpr(oper, _) }, + params.map { db.typechecker.typeOf(_, oper) } + .map { db.types.rootType(_) } + ).getUDF } } } + def getSparkType(t:Type) : DataType = + OperatorTranslation.getSparkType(db.types.rootType(t)) +} + +object OperatorTranslation { + + val defaultDate = DateTimeUtils.toJavaDate(0) + val defaultTimestamp = DateTimeUtils.toJavaTimestamp(0L) + val defaultInterval = new Period() + def dataTypeFromString(dataTypeString:String): DataType = { dataTypeString match { case "BinaryType" => BinaryType @@ -732,15 +701,15 @@ object OperatorTranslation } } - def mimirSchemaToStructType(schema:Seq[(String, Type)]):StructType = { + def mimirSchemaToStructType(schema:Seq[(String, BaseType)]):StructType = { StructType(schema.map(col => StructField(col._1, getSparkType(col._2), true))) } - def structTypeToMimirSchema(schema:StructType): Seq[(String, Type)] = { - schema.fields.map(col => (col.name, 
getMimirType(col.dataType))) + def structTypeToMimirSchema(schema:StructType): Seq[(String, BaseType)] = { + schema.fields.map(col => (col.name, OperatorTranslation.getMimirType(col.dataType))) } - - def getSparkType(t:Type) : DataType = { + + def getSparkType(t:BaseType) : DataType = { t match { case TInt() => LongType case TFloat() => DoubleType @@ -752,8 +721,6 @@ object OperatorTranslation case TAny() => StringType case TTimestamp() => TimestampType case TInterval() => StringType - case TUser(name) => getSparkType(mimir.algebra.TypeRegistry.registeredTypes(name)._2) - case _ => StringType } } @@ -770,7 +737,7 @@ object OperatorTranslation } } - def getMimirType(dataType: DataType): Type = { + def getMimirType(dataType: DataType): BaseType = { dataType match { case IntegerType => TInt() case DoubleType => TFloat() @@ -783,7 +750,7 @@ object OperatorTranslation } } - def getNative(value:PrimitiveValue, t:Type): Any = { + def getNative(value:PrimitiveValue, t:BaseType): Any = { value match { case NullPrimitive() => t match { case TInt() => 0L @@ -796,7 +763,6 @@ object OperatorTranslation case TAny() => "" case TTimestamp() => OperatorTranslation.defaultTimestamp case TInterval() => "" - case TUser(name) => getNative(value, mimir.algebra.TypeRegistry.registeredTypes(name)._2) case x => "" } case RowIdPrimitive(s) => s @@ -817,6 +783,62 @@ object OperatorTranslation case _ => oper.children.map(extractTables(_)).flatten } } + + def mimirPrimitiveToSparkPrimitive(primitive : PrimitiveValue) : Literal = { + primitive match { + case NullPrimitive() => Literal(null) + case RowIdPrimitive(s) => Literal(s) + case StringPrimitive(s) => Literal(s) + case IntPrimitive(i) => Literal(i) + case FloatPrimitive(f) => Literal(f) + case BoolPrimitive(b) => Literal(b) + case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => Literal.create(SparkUtils.convertTimestamp(ts), TimestampType) + case dt@DatePrimitive(y,m,d) => Literal.create(SparkUtils.convertDate(dt), DateType) + case x => Literal(x.asString) + } + } + + def mimirPrimitiveToSparkInternalRowValue(primitive : PrimitiveValue) : Any = { + primitive match { + case NullPrimitive() => null + case RowIdPrimitive(s) => UTF8String.fromString(s) + case StringPrimitive(s) => UTF8String.fromString(s) + case IntPrimitive(i) => i + case FloatPrimitive(f) => f + case BoolPrimitive(b) => b + case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => SparkUtils.convertTimestamp(ts)//DateTimeUtils.fromJavaTimestamp(SparkUtils.convertTimestamp(ts)) + case dt@DatePrimitive(y,m,d) => SparkUtils.convertDate(dt)//DateTimeUtils.fromJavaDate(SparkUtils.convertDate(dt)) + case x => UTF8String.fromString(x.asString) + } + } + + def mimirPrimitiveToSparkExternalRowValue(primitive : PrimitiveValue) : Any = { + primitive match { + case NullPrimitive() => null + case RowIdPrimitive(s) => s + case StringPrimitive(s) => s + case IntPrimitive(i) => i + case FloatPrimitive(f) => f + case BoolPrimitive(b) => b + case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => SparkUtils.convertTimestamp(ts) + case dt@DatePrimitive(y,m,d) => SparkUtils.convertDate(dt) + case x => x.asString + } + } + + def mimirPrimitiveToSparkInternalInlineFuncParam(primitive : PrimitiveValue) : Any = { + primitive match { + case IntPrimitive(i) => i.toInt + case x => mimirPrimitiveToSparkInternalRowValue(x) + } + } + + def mimirPrimitiveToSparkExternalInlineFuncParam(primitive : PrimitiveValue) : Any = { + primitive match { + case IntPrimitive(i) => i.toInt + case x => mimirPrimitiveToSparkExternalRowValue(x) + } + } /*def 
mimirOpToDF(sqlContext:SQLContext, oper:Operator) : DataFrame = { val sparkOper = OperatorTranslation.mimirOpToSparkOp(oper) @@ -859,7 +881,7 @@ object OperatorTranslation } class MimirUDF { - def getPrimitive(t:Type, value:Any) = value match { + def getPrimitive(t:BaseType, value:Any) = value match { case null => NullPrimitive() case _ => t match { //case TInt() => IntPrimitive(value.asInstanceOf[Long]) @@ -870,10 +892,9 @@ class MimirUDF { case TString() => StringPrimitive(value.asInstanceOf[String]) case TBool() => BoolPrimitive(value.asInstanceOf[Boolean]) case TRowId() => RowIdPrimitive(value.asInstanceOf[String]) - case TType() => TypePrimitive(Type.fromString(value.asInstanceOf[String])) - //case TAny() => NullPrimitive() - //case TUser(name) => name.toLowerCase - //case TInterval() => Primitive(value.asInstanceOf[Long]) + case TType() => BaseType.fromString(value.asInstanceOf[String]) + .map { TypePrimitive(_) } + .getOrElse { NullPrimitive() } case _ => StringPrimitive(value.asInstanceOf[String]) } } @@ -895,14 +916,14 @@ class MimirUDF { } -case class BestGuessUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression], hints:Seq[Expression]) extends MimirUDF { +case class BestGuessUDF(oper:Operator, model:Model, idx:Int, args:Seq[org.apache.spark.sql.catalyst.expressions.Expression], hints:Seq[org.apache.spark.sql.catalyst.expressions.Expression]) extends MimirUDF { val sparkVarType = OperatorTranslation.getSparkType(model.varType(idx, model.argTypes(idx))) - val sparkArgs = (args.map(arg => OperatorTranslation.mimirExprToSparkExpr(oper,arg)) ++ hints.map(hint => OperatorTranslation.mimirExprToSparkExpr(oper,hint))).toList.toSeq - val sparkArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg)) ++ model.hintTypes(idx).map(hint => OperatorTranslation.getSparkType(hint))).toList.toSeq + val allArgs = (args ++ hints).toList.toSeq + val allArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg)) ++ model.hintTypes(idx).map(hint => OperatorTranslation.getSparkType(hint))).toList.toSeq def extractArgsAndHints(args:Seq[Any]) : (Seq[PrimitiveValue],Seq[PrimitiveValue]) ={ try{ - val getParamPrimitive:(Type, Any) => PrimitiveValue = if(sparkArgs.length > 22) (t, v) => { + val getParamPrimitive:(BaseType, Any) => PrimitiveValue = if(allArgs.length > 22) (t, v) => { v match { case null => NullPrimitive() case _ => Cast(t,StringPrimitive(v.asInstanceOf[String])) @@ -919,7 +940,7 @@ case class BestGuessUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression map( arg => getParamPrimitive(arg._1, args(argList.length+arg._2))) (argList,hintList) }catch { - case t: Throwable => throw new Exception(s"BestGuessUDF Error Extracting Args and Hints: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${sparkArgs.mkString(",")}]", t) + case t: Throwable => throw new Exception(s"BestGuessUDF Error Extracting Args and Hints: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${allArgs.mkString(",")}]", t) } } def varArgs(args:Any*):Any = { @@ -930,7 +951,7 @@ case class BestGuessUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression } def getUDF = ScalaUDF( - sparkArgs.length match { + allArgs.length match { case 0 => () => { getNative(model.bestGuess(idx, Seq(), Seq())) } @@ -1025,15 +1046,15 @@ case class BestGuessUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression case x => varArgs _ }, sparkVarType, - if(sparkArgs.length > 22) Seq(CreateArray(sparkArgs)) else sparkArgs, - 
if(sparkArgs.length > 22) Seq(ArrayType(StringType)) else sparkArgTypes, + if(allArgs.length > 22) Seq(CreateArray(allArgs)) else allArgs, + if(allArgs.length > 22) Seq(ArrayType(StringType)) else allArgTypes, Some(model.name)) } -case class SampleUDF(oper:Operator, model:Model, idx:Int, seed:Expression, args:Seq[Expression], hints:Seq[Expression]) extends MimirUDF { +case class SampleUDF(oper:Operator, model:Model, idx:Int, seed:Expression, args:Seq[org.apache.spark.sql.catalyst.expressions.Expression], hints:Seq[org.apache.spark.sql.catalyst.expressions.Expression]) extends MimirUDF { val sparkVarType = OperatorTranslation.getSparkType(model.varType(idx, model.argTypes(idx))) - val sparkArgs = (args.map(arg => OperatorTranslation.mimirExprToSparkExpr(oper,arg)) ++ hints.map(hint => OperatorTranslation.mimirExprToSparkExpr(oper,hint))).toList.toSeq - val sparkArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg)) ++ model.hintTypes(idx).map(hint => OperatorTranslation.getSparkType(hint))).toList.toSeq + val allArgs = (args ++ hints).toList.toSeq + val allArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg)) ++ model.hintTypes(idx).map(hint => OperatorTranslation.getSparkType(hint))).toList.toSeq def extractArgsAndHintsSeed(args:Seq[Any]) : (Long, Seq[PrimitiveValue],Seq[PrimitiveValue]) ={ try{ @@ -1048,7 +1069,7 @@ case class SampleUDF(oper:Operator, model:Model, idx:Int, seed:Expression, args: map( arg => getPrimitive(arg._1, args(argList.length+arg._2))) (seedp, argList,hintList) }catch { - case t: Throwable => throw new Exception(s"SampleUDF Error Extracting Args and Hints: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${sparkArgs.mkString(",")}]", t) + case t: Throwable => throw new Exception(s"SampleUDF Error Extracting Args and Hints: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${allArgs.mkString(",")}]", t) } } def varArgs(args:Any*): Any = { @@ -1058,7 +1079,7 @@ case class SampleUDF(oper:Operator, model:Model, idx:Int, seed:Expression, args: } def getUDF = ScalaUDF( - sparkArgs.length match { + allArgs.length match { case 0 => () => { getNative(model.sample(idx, 0, Seq(), Seq())) } @@ -1153,21 +1174,21 @@ case class SampleUDF(oper:Operator, model:Model, idx:Int, seed:Expression, args: case x => varArgs _ }, sparkVarType, - if(sparkArgs.length > 22) Seq(CreateStruct(sparkArgs)) else sparkArgs, - if(sparkArgs.length > 22) Seq(getStructType(sparkArgTypes)) else sparkArgTypes, + if(allArgs.length > 22) Seq(CreateStruct(allArgs)) else allArgs, + if(allArgs.length > 22) Seq(getStructType(allArgTypes)) else allArgTypes, Some(model.name)) } -case class AckedUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression]) extends MimirUDF { - val sparkArgs = (args.map(arg => OperatorTranslation.mimirExprToSparkExpr(oper,arg))).toList.toSeq - val sparkArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg))).toList.toSeq +case class AckedUDF(oper:Operator, model:Model, idx:Int, args:Seq[org.apache.spark.sql.catalyst.expressions.Expression]) extends MimirUDF { + val allArgs = args.toList.toSeq + val allArgTypes = (model.argTypes(idx).map(arg => OperatorTranslation.getSparkType(arg))).toList.toSeq def extractArgs(args:Seq[Any]) : Seq[PrimitiveValue] = { try{ model.argTypes(idx). zipWithIndex. 
map( arg => getPrimitive(arg._1, args(arg._2))) }catch { - case t: Throwable => throw new Exception(s"AckedUDF Error Extracting Args: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${sparkArgs.mkString(",")}]", t) + case t: Throwable => throw new Exception(s"AckedUDF Error Extracting Args: \n\tModel: ${model.name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${allArgs.mkString(",")}]", t) } } def varArgs(args:Any*):Any = { @@ -1177,7 +1198,7 @@ case class AckedUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression]) e } def getUDF = ScalaUDF( - sparkArgs.length match { + allArgs.length match { case 0 => () => { new java.lang.Boolean(model.isAcknowledged(idx, Seq())) } @@ -1272,22 +1293,26 @@ case class AckedUDF(oper:Operator, model:Model, idx:Int, args:Seq[Expression]) e case x => varArgs _ }, BooleanType, - if(sparkArgs.length > 22) Seq(CreateStruct(sparkArgs)) else sparkArgs, - if(sparkArgs.length > 22) Seq(getStructType(sparkArgTypes)) else sparkArgTypes, + if(allArgs.length > 22) Seq(CreateStruct(allArgs)) else allArgs, + if(allArgs.length > 22) Seq(getStructType(allArgTypes)) else allArgTypes, Some(model.name)) } -case class FunctionUDF(oper:Operator, name:String, function:RegisteredFunction, params:Seq[Expression], argTypes:Seq[Type]) extends MimirUDF { - val sparkArgs = (params.map(arg => OperatorTranslation.mimirExprToSparkExpr(oper,arg))).toList.toSeq - val sparkArgTypes = argTypes.map(argT => OperatorTranslation.getSparkType(argT)).toList.toSeq - val dataType = function match { case NativeFunction(_, _, tc, _) => OperatorTranslation.getSparkType(tc(argTypes)) } +case class FunctionUDF(oper:Operator, name:String, function:RegisteredFunction, params:Seq[org.apache.spark.sql.catalyst.expressions.Expression], argTypes:Seq[BaseType]) extends MimirUDF { + val allArgs = params.toList.toSeq + val allArgTypes = argTypes.map(argT => OperatorTranslation.getSparkType(argT)).toList.toSeq + val dataType = function match { + case NativeFunction(_, _, tc, _) => OperatorTranslation.getSparkType(tc(argTypes)) + case ExpressionFunction(_,_,_) => throw new Exception("Can't create Spark UDF for Expression Functions") + case FoldFunction(_,_) => throw new Exception("Can't create Spark UDF for Fold Functions") + } def extractArgs(args:Seq[Any]) : Seq[PrimitiveValue] = { try{ argTypes. zipWithIndex. 
map( arg => getPrimitive(arg._1, args(arg._2))) }catch { - case t: Throwable => throw new Exception(s"FunctionUDF Error Extracting Args: \n\tModel: ${name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${sparkArgs.mkString(",")}]", t) + case t: Throwable => throw new Exception(s"FunctionUDF Error Extracting Args: \n\tModel: ${name} \n\tArgs: [${args.mkString(",")}] \n\tSparkArgs: [${allArgs.mkString(",")}]", t) } } def varArgs(evaluator:Seq[PrimitiveValue] => PrimitiveValue)(args:Any*):Any = { @@ -1298,8 +1323,10 @@ case class FunctionUDF(oper:Operator, name:String, function:RegisteredFunction, def getUDF = ScalaUDF( function match { + case ExpressionFunction(_,_,_) => throw new Exception("Can't create Spark UDF for Expression Functions") + case FoldFunction(_,_) => throw new Exception("Can't create Spark UDF for Fold Functions") case NativeFunction(_, evaluator, typechecker, _) => - sparkArgs.length match { + allArgs.length match { case 0 => () => { getNative(evaluator(Seq())) } @@ -1395,8 +1422,8 @@ case class FunctionUDF(oper:Operator, name:String, function:RegisteredFunction, } }, dataType, - if(sparkArgs.length > 22) Seq(CreateStruct(sparkArgs)) else sparkArgs, - if(sparkArgs.length > 22) Seq(getStructType(sparkArgTypes)) else sparkArgTypes, + if(allArgs.length > 22) Seq(CreateStruct(allArgs)) else allArgs, + if(allArgs.length > 22) Seq(getStructType(allArgTypes)) else allArgTypes, Some(name)) } diff --git a/src/main/scala/mimir/algebra/spark/function/SparkFunctions.scala b/src/main/scala/mimir/algebra/spark/function/SparkFunctions.scala index e7aa2f29..1be2ff18 100644 --- a/src/main/scala/mimir/algebra/spark/function/SparkFunctions.scala +++ b/src/main/scala/mimir/algebra/spark/function/SparkFunctions.scala @@ -2,16 +2,16 @@ package mimir.algebra.spark.function import mimir.algebra.function.FunctionRegistry import mimir.algebra.PrimitiveValue -import mimir.algebra.Type +import mimir.algebra.BaseType import com.typesafe.scalalogging.slf4j.LazyLogging object SparkFunctions extends LazyLogging { - val sparkFunctions = scala.collection.mutable.Map[String, (Seq[PrimitiveValue] => PrimitiveValue, Seq[Type] => Type)]() + val sparkFunctions = scala.collection.mutable.Map[String, (Seq[PrimitiveValue] => PrimitiveValue, Seq[BaseType] => BaseType)]() - def addSparkFunction(fname:String,eval:Seq[PrimitiveValue] => PrimitiveValue, typechecker: Seq[Type] => Type) : Unit = { + def addSparkFunction(fname:String,eval:Seq[PrimitiveValue] => PrimitiveValue, typechecker: Seq[BaseType] => BaseType) : Unit = { sparkFunctions.put(fname, (eval, typechecker)) } diff --git a/src/main/scala/mimir/algebra/typeregistry/DefaultTypeRegistry.scala b/src/main/scala/mimir/algebra/typeregistry/DefaultTypeRegistry.scala new file mode 100644 index 00000000..f459e8db --- /dev/null +++ b/src/main/scala/mimir/algebra/typeregistry/DefaultTypeRegistry.scala @@ -0,0 +1,67 @@ +package mimir.algebra.typeregistry + +import mimir.algebra._ + +case class RegisteredType(name:String, constraints:Set[TypeConstraint], basedOn:Type = TString()) + +object DefaultTypeRegistry extends TypeRegistry with Serializable +{ + val types:Seq[RegisteredType] = Seq( + RegisteredType("credits", Set(RegexpConstraint("^[0-9]{1,3}(\\.[0-9]{0,2})?$".r))), + RegisteredType("email", Set(RegexpConstraint("^[a-z0-9._%+-]+@[a-z0-9.-]+\\.[a-z]{2,}$".r))), + RegisteredType("productid", Set(RegexpConstraint("^P\\d+$".r))), + RegisteredType("firecompany", Set(RegexpConstraint("^[a-zA-Z]\\d{3}$".r))), + RegisteredType("zipcode", 
Set(RegexpConstraint("^\\d{5}(?:[-\\s]\\d{4})?$".r))), + RegisteredType("container", Set(RegexpConstraint("^[A-Z]{4}[0-9]{7}$".r))), + RegisteredType("carriercode", Set(RegexpConstraint("^[A-Z]{4}$".r))), + RegisteredType("mmsi", Set(RegexpConstraint("^(MID\\d{6}|0MID\\d{5}|00MID\\d{4})$".r))), + RegisteredType("billoflanding", Set(RegexpConstraint("^[A-Z]{8}[0-9]{8}$".r))), + RegisteredType("imo_code", Set(RegexpConstraint("^\\d{7}$".r))) + ) + val typesByName = types.map { t => t.name -> t }.toMap + val typesByBaseType = types.groupBy { _.basedOn } + val indexesByName = types.zipWithIndex.map { t => t._1.name -> t._2}.toMap + + def getDefinition(name:String): RegisteredType = + { + typesByName.get(name) match { + case Some(definition) => definition + case None => throw new RAException(s"Undefined user defined type: $name") + } + } + def supportsUserType(name:String): Boolean = + typesByName contains name + + def parentOfUserType(t: TUser): Type = + getDefinition(t.name).basedOn + + def testForUserTypes(value: String, validBaseTypes:Set[BaseType]): Set[TUser] = + { + // Base types have already been matched by the caller. + // Keep user-defined types whose base type matched and whose + // constraints all accept the value. + validBaseTypes + .map { + typesByBaseType.get(_).toSeq.flatten + .filter { _.constraints.forall { _.test(value) } } + .map { t => TUser(t.name) } + } + .flatten + .toSet + } + def userTypeCaster(t: TUser, target:Expression, orElse: Expression): Expression = + { + Conditional( + ExpressionUtils.makeAnd( + getDefinition(t.name) + .constraints + .map { _.tester(target) } + ), + target, + orElse + ) + } + def userTypeForId(i: Integer) = TUser(types(i).name) + def idForUserType(t: TUser): Integer = indexesByName(t.name) + def getSerializable = this +} \ No newline at end of file diff --git a/src/main/scala/mimir/algebra/typeregistry/TypeConstraint.scala b/src/main/scala/mimir/algebra/typeregistry/TypeConstraint.scala new file mode 100644 index 00000000..d83d25c6 --- /dev/null +++ b/src/main/scala/mimir/algebra/typeregistry/TypeConstraint.scala @@ -0,0 +1,49 @@ +package mimir.algebra.typeregistry + +import mimir.algebra._ +import scala.util.matching.Regex +import mimir.util.TextUtils + +sealed abstract class TypeConstraint +{ + def test(v: String): Boolean + def tester(target:Expression): Expression +} + +case class RegexpConstraint(matcher: Regex) extends TypeConstraint +{ + def test(v:String) = { (matcher findFirstMatchIn v) != None } + def tester(target:Expression) = Function("rlike", Seq(target, StringPrimitive(matcher.regex))) +} + +case class EnumConstraint(values: Set[String], caseSensitive: Boolean = false) extends TypeConstraint +{ + lazy val caseSensitiveValues = + if(caseSensitive) { values } + else { values.map { _.toLowerCase } } + + def test(v:String) = + if(caseSensitive) { values contains v } + else { caseSensitiveValues contains v.toLowerCase } + def tester(target:Expression) = + ExpressionUtils.makeOr( + caseSensitiveValues.map { v => + Comparison(Cmp.Eq, target, StringPrimitive(v)) + } + ) + def base = TString() +} + +case class IntExpressionConstraint(constraint: Expression, targetType: BaseType) extends TypeConstraint +{ + lazy val eval = new Eval() + + def test(v: String) = + eval.evalBool(constraint, Map( + "x" -> TextUtils.parsePrimitive(targetType, v) + )) + def tester(target:Expression) = + Eval.inline(constraint, Map( + "x" -> target + )) +} \ No newline at end of file diff --git a/src/main/scala/mimir/algebra/typeregistry/TypeRegistry.scala b/src/main/scala/mimir/algebra/typeregistry/TypeRegistry.scala new file mode 100644 index
00000000..149a0b7e --- /dev/null +++ b/src/main/scala/mimir/algebra/typeregistry/TypeRegistry.scala @@ -0,0 +1,66 @@ +package mimir.algebra.typeregistry + +import mimir.algebra._ + +abstract class TypeRegistry +{ + def userTypeForId(i: Integer): TUser + def idForUserType(t: TUser): Integer + def supportsUserType(name:String): Boolean + + def testForUserTypes(record: String, validBaseTypes:Set[BaseType]): Set[TUser] + def userTypeCaster(t:TUser, target: Expression, orElse: Expression): Expression + + def parentOfUserType(t:TUser): Type + def rootType(t:Type): BaseType = + t match { + case u:TUser => rootType(parentOfUserType(u)); + case b:BaseType => b + } + def parentType(t:Type): Type = + t match { + case u:TUser => parentOfUserType(u) + case _:BaseType => t + } + + def testForTypes(value:String): Set[Type] = + { + val validBaseTypes = + BaseType.tests + .filter { _._2.findFirstMatchIn(value) != None } + .map { _._1 } + .toSet + testForUserTypes(value, validBaseTypes) ++ validBaseTypes + } + def typeCaster(t: Type, target: Expression, orElse: Expression = NullPrimitive()): Expression = + { + val castTarget = Function("CAST", Seq(target, TypePrimitive(rootType(t)))) + t match { + case u:TUser => userTypeCaster(u, castTarget, orElse) + case b:BaseType => castTarget + } + } + + def fromString(name:String) = + { + BaseType.fromString(name) + .getOrElse { + if(supportsUserType(name)){ TUser(name) } + else { throw new RAException("Unsupported Type: "+name) } + } + } + + val idForBaseType = BaseType.idTypeOrder.zipWithIndex.toMap + + def typeForId(i:Integer) = + if(i < BaseType.idTypeOrder.size){ BaseType.idTypeOrder(i) } + else { userTypeForId(i - BaseType.idTypeOrder.size) } + def idForType(t:Type):Integer = + t match { + case u:TUser => idForUserType(u) + case b:BaseType => idForBaseType(b) + } + + def getSerializable:(TypeRegistry with Serializable) + +} diff --git a/src/main/scala/mimir/ctables/CTExplainer.scala b/src/main/scala/mimir/ctables/CTExplainer.scala index cc77e9df..8e35a733 100644 --- a/src/main/scala/mimir/ctables/CTExplainer.scala +++ b/src/main/scala/mimir/ctables/CTExplainer.scala @@ -339,7 +339,7 @@ class CTExplainer(db: Database) extends LazyLogging { val optQuery = db.compiler.optimize(inlinedQuery) - val finalSchema = db.typechecker.schemaOf(optQuery) + val finalSchema = db.typechecker.baseSchemaOf(optQuery) //val sqlQuery = db.ra.convert(optQuery) @@ -348,7 +348,10 @@ class CTExplainer(db: Database) extends LazyLogging { val results = db.backend.execute(optQuery)//sqlQuery) val baseData = - SparkUtils.extractAllRows(results, finalSchema.map(_._2)).flush + SparkUtils.extractAllRows( + results, + finalSchema.map { _._2 } + ).flush if(baseData.isEmpty){ val resultRowString = baseData.map( _.mkString(", ") ).mkString("\n") diff --git a/src/main/scala/mimir/ctables/CTPercolator.scala b/src/main/scala/mimir/ctables/CTPercolator.scala index 6a242350..0b288b21 100644 --- a/src/main/scala/mimir/ctables/CTPercolator.scala +++ b/src/main/scala/mimir/ctables/CTPercolator.scala @@ -3,6 +3,7 @@ package mimir.ctables import java.sql.SQLException import com.typesafe.scalalogging.slf4j.LazyLogging +import mimir.Database import mimir.algebra._ import mimir.util._ import mimir.optimizer._ @@ -597,8 +598,8 @@ object CTPercolator } } - def percolateGProM(oper: Operator): (Operator, Map[String,Expression], Expression) = + def percolateGProM(oper: Operator, db: Database): (Operator, Map[String,Expression], Expression) = { - mimir.algebra.gprom.OperatorTranslation.compileTaintWithGProM(oper) + 
db.gpromTranslator.compileTaintWithGProM(oper) } } diff --git a/src/main/scala/mimir/ctables/CTPercolatorClassic.scala b/src/main/scala/mimir/ctables/CTPercolatorClassic.scala index 7b7fbef6..a3c89226 100644 --- a/src/main/scala/mimir/ctables/CTPercolatorClassic.scala +++ b/src/main/scala/mimir/ctables/CTPercolatorClassic.scala @@ -264,6 +264,7 @@ object CTPercolatorClassic { case ProvenanceOf(psel) => { percolateOne(psel) } + case _ => ??? // This code is deprecated. It's only getting called for research use. } } @@ -284,7 +285,7 @@ object CTPercolatorClassic { case Union(_,_) => false case Join(_,_) => false case Aggregate(_,_,_) => false - case (Annotate(_, _) | ProvenanceOf(_) | Recover(_, _)) => ??? + case _ => ??? // This code is deprecated. It's only getting called for research use. } } @@ -382,7 +383,7 @@ object CTPercolatorClassic { Table(name, alias, sch, metadata) } - case (Annotate(_, _) | ProvenanceOf(_) | Recover(_, _)) => ??? + case _ => ??? // This code is deprecated. It's only getting called for research use. } } diff --git a/src/main/scala/mimir/ctables/vgterm/BestGuess.scala b/src/main/scala/mimir/ctables/vgterm/BestGuess.scala index 22b9dad9..28f6e889 100644 --- a/src/main/scala/mimir/ctables/vgterm/BestGuess.scala +++ b/src/main/scala/mimir/ctables/vgterm/BestGuess.scala @@ -11,7 +11,7 @@ case class BestGuess( vgHints: Seq[Expression] ) extends Proc(vgArgs++vgHints) { override def toString() = "{{ BEST GUESS: "+model.name+";"+idx+"["+vgArgs.mkString(", ")+"]["+vgHints.mkString(", ")+"] }}" - override def getType(bindings: Seq[Type]):Type = model.varType(idx, bindings) + override def getType(bindings: Seq[BaseType]):BaseType = model.varType(idx, bindings) override def children: Seq[Expression] = vgArgs ++ vgHints override def rebuild(x: Seq[Expression]) = { val (a, h) = x.splitAt(vgArgs.length) diff --git a/src/main/scala/mimir/ctables/vgterm/DomainDumper.scala b/src/main/scala/mimir/ctables/vgterm/DomainDumper.scala index 12c4b6ab..feda1042 100644 --- a/src/main/scala/mimir/ctables/vgterm/DomainDumper.scala +++ b/src/main/scala/mimir/ctables/vgterm/DomainDumper.scala @@ -13,7 +13,7 @@ case class DomainDumper( vgHints: Seq[Expression] ) extends Proc(vgArgs++vgHints) { override def toString() = "{{ DOMAIN DUMP: "+model.name+";"+idx+"["+vgArgs.mkString(", ")+"]["+vgHints.mkString(", ")+"] }}" - override def getType(bindings: Seq[Type]):Type = TString() + override def getType(bindings: Seq[BaseType]):BaseType = TString() override def children: Seq[Expression] = vgArgs ++ vgHints override def rebuild(x: Seq[Expression]) = { val (a, h) = x.splitAt(vgArgs.length) diff --git a/src/main/scala/mimir/ctables/vgterm/IsAcknowledged.scala b/src/main/scala/mimir/ctables/vgterm/IsAcknowledged.scala index 6dd46181..88fb5c25 100644 --- a/src/main/scala/mimir/ctables/vgterm/IsAcknowledged.scala +++ b/src/main/scala/mimir/ctables/vgterm/IsAcknowledged.scala @@ -12,7 +12,7 @@ case class IsAcknowledged( ) extends Proc( vgArgs ) { - def getType(argTypes: Seq[Type]): Type = TBool() + def getType(argTypes: Seq[BaseType]): BaseType = TBool() def get(v: Seq[PrimitiveValue]): PrimitiveValue = { BoolPrimitive(model.isAcknowledged(idx, v)) diff --git a/src/main/scala/mimir/ctables/vgterm/Sampler.scala b/src/main/scala/mimir/ctables/vgterm/Sampler.scala index 0cf06d40..ea9ea316 100644 --- a/src/main/scala/mimir/ctables/vgterm/Sampler.scala +++ b/src/main/scala/mimir/ctables/vgterm/Sampler.scala @@ -14,7 +14,7 @@ case class Sampler( ) extends Proc( (seed :: (vgArgs.toList ++ vgHints.toList)) ) { - 
def getType(argTypes: Seq[Type]): Type = + def getType(argTypes: Seq[BaseType]): BaseType = model.varType(idx, argTypes) def get(v: Seq[PrimitiveValue]): PrimitiveValue = { diff --git a/src/main/scala/mimir/exec/Compiler.scala b/src/main/scala/mimir/exec/Compiler.scala index 45e9afbf..2ad8466a 100644 --- a/src/main/scala/mimir/exec/Compiler.scala +++ b/src/main/scala/mimir/exec/Compiler.scala @@ -121,7 +121,7 @@ class Compiler(db: Database) extends LazyLogging { new AggregateResultIterator( gbCols, aggFunctions, - requiredColumnsInOrder.map { col => (col, sourceColumnTypes(col)) }, + requiredColumnsInOrder.map { col => (col, db.types.rootType(sourceColumnTypes(col))) }, jointIterator, db ) @@ -141,7 +141,7 @@ class Compiler(db: Database) extends LazyLogging { val (schema, rootIterator) = rootIteratorGen({ if(ExperimentalOptions.isEnabled("GPROM-OPTIMIZE") && db.backend.isInstanceOf[mimir.sql.GProMBackend] ) { - OperatorTranslation.optimizeWithGProM(oper) + db.gpromTranslator.optimizeWithGProM(oper) } else { optimize(oper) } @@ -152,7 +152,7 @@ class Compiler(db: Database) extends LazyLogging { new ProjectionResultIterator( outputCols.map( projections(_) ), annotationCols.map( projections(_) ).toSeq, - schema, + schema.map { case (name,t) => (name, db.types.rootType(t)) }, rootIterator, db ) @@ -164,6 +164,7 @@ class Compiler(db: Database) extends LazyLogging { (schema, new SparkResultIterator( schema, oper, db.backend, + db.types, db.backend.dateType )) } @@ -173,6 +174,7 @@ class Compiler(db: Database) extends LazyLogging { (sqlSchema, new JDBCResultIterator( sqlSchema, sql, db.metadataBackend, + db.types, db.backend.dateType )) } @@ -185,7 +187,7 @@ class Compiler(db: Database) extends LazyLogging { val optimized = { if(ExperimentalOptions.isEnabled("GPROM-OPTIMIZE") && db.backend.isInstanceOf[mimir.sql.GProMBackend] ) { - OperatorTranslation.optimizeWithGProM(oper) + db.gpromTranslator.optimizeWithGProM(oper) } else { optimize(oper) } diff --git a/src/main/scala/mimir/exec/EvalInlined.scala b/src/main/scala/mimir/exec/EvalInlined.scala index 6abbaed2..51b90eca 100644 --- a/src/main/scala/mimir/exec/EvalInlined.scala +++ b/src/main/scala/mimir/exec/EvalInlined.scala @@ -49,9 +49,9 @@ class EvalInlined[T](scope: Map[String, (Type, (T => PrimitiveValue))], db: Data case TType() => val v = compileForType(e); checkNull { (t:T) => TypePrimitive(v(t)) } case TDate() => checkNull { compileForDate(e) } case TTimestamp() => checkNull { compileForTimestamp(e) } - case TInterval() => checkNull { compileForInterval(e) } + case TInterval() => checkNull { compileForInterval(e) } case TRowId() => checkNull { compileForRowId(e) } - case TUser(ut) => checkNull { compile(e, TypeRegistry.baseType(ut)) } + case u:TUser => checkNull { compile(e, db.types.rootType(u)) } } } } @@ -191,11 +191,9 @@ class EvalInlined[T](scope: Map[String, (Type, (T => PrimitiveValue))], db: Data } } case Comparison(op, lhs, rhs) => { - (op, Type.rootType(typeOf(lhs)), Type.rootType(typeOf(rhs))) match { + (op, db.types.rootType(typeOf(lhs)), db.types.rootType(typeOf(rhs))) match { case (_, TAny(), _) => throw new RAException(s"Invalid comparison on TAny: $e") case (_, _, TAny()) => throw new RAException(s"Invalid comparison on TAny: $e") - case (_, TUser(n), _) => throw new RAException(s"Internal error in Type.rootType($n): $e") - case (_, _, TUser(n)) => throw new RAException(s"Internal error in Type.rootType($n): $e") case (Cmp.Eq, TBool(), TBool()) => compileBinary(lhs, rhs, compileForBool) { _ == _ } case (Cmp.Eq, TInt(), 
TInt()) => compileBinary(lhs, rhs, compileForLong) { _ == _ } case (Cmp.Eq, ( TInt() | TFloat() ), ( TInt() | TFloat() ) ) diff --git a/src/main/scala/mimir/exec/mode/BestGuess.scala b/src/main/scala/mimir/exec/mode/BestGuess.scala index a9ddf6b7..3e64b435 100644 --- a/src/main/scala/mimir/exec/mode/BestGuess.scala +++ b/src/main/scala/mimir/exec/mode/BestGuess.scala @@ -50,7 +50,7 @@ object BestGuess if(false && ExperimentalOptions.isEnabled("GPROM-DETERMINISM") && ExperimentalOptions.isEnabled("GPROM-PROVENANCE") && ExperimentalOptions.isEnabled("GPROM-BACKEND")){ - OperatorTranslation.compileProvenanceAndTaintWithGProM(oper) + db.gpromTranslator.compileProvenanceAndTaintWithGProM(oper) } else { // The names that the provenance compilation step assigns will @@ -60,7 +60,7 @@ object BestGuess val provenance = if(ExperimentalOptions.isEnabled("GPROM-PROVENANCE") && ExperimentalOptions.isEnabled("GPROM-BACKEND")) - { Provenance.compileGProM(oper) } + { db.gpromTranslator.compileProvenanceWithGProM(oper) } else { Provenance.compile(oper) } oper = provenance._1 @@ -72,7 +72,7 @@ object BestGuess // Tag rows/columns with provenance metadata val tagging = if(ExperimentalOptions.isEnabled("GPROM-DETERMINISM") && ExperimentalOptions.isEnabled("GPROM-BACKEND")) - { CTPercolator.percolateGProM(oper) } + { CTPercolator.percolateGProM(oper, db) } else { CTPercolator.percolateLite(oper, db.models.get(_)) } (tagging._1, provenanceCols, diff --git a/src/main/scala/mimir/exec/mode/DumpDomain.scala b/src/main/scala/mimir/exec/mode/DumpDomain.scala index a54301ae..9eb23157 100644 --- a/src/main/scala/mimir/exec/mode/DumpDomain.scala +++ b/src/main/scala/mimir/exec/mode/DumpDomain.scala @@ -55,7 +55,7 @@ object DumpDomain val provenance = if(ExperimentalOptions.isEnabled("GPROM-PROVENANCE") && db.backend.isInstanceOf[mimir.sql.GProMBackend]) - { Provenance.compileGProM(oper) } + { db.gpromTranslator.compileProvenanceWithGProM(oper) } else { Provenance.compile(oper) } oper = provenance._1 @@ -67,7 +67,7 @@ object DumpDomain // Tag rows/columns with provenance metadata val tagging = if(ExperimentalOptions.isEnabled("GPROM-DETERMINISM") && db.backend.isInstanceOf[mimir.sql.GProMBackend]) - { CTPercolator.percolateGProM(oper) } + { CTPercolator.percolateGProM(oper, db) } else { CTPercolator.percolateLite(oper, db.models.get(_)) } oper = tagging._1 val colDeterminism = tagging._2.filter( col => rawColumns(col._1) ) diff --git a/src/main/scala/mimir/exec/result/AggregateResultIterator.scala b/src/main/scala/mimir/exec/result/AggregateResultIterator.scala index 7bacda38..5f932264 100644 --- a/src/main/scala/mimir/exec/result/AggregateResultIterator.scala +++ b/src/main/scala/mimir/exec/result/AggregateResultIterator.scala @@ -101,32 +101,38 @@ class FirstAggregate(arg: Row => PrimitiveValue) extends AggregateValue class AggregateResultIterator( groupByColumns: Seq[Var], aggregates: Seq[AggFunction], - inputSchema: Seq[(String,Type)], + inputSchema: Seq[(String,BaseType)], src: ResultIterator, db: Database ) extends ResultIterator with LazyLogging { - private val typeOfInputColumn: Map[String, Type] = inputSchema.toMap + private val typeOfInputColumn: Map[String, BaseType] = inputSchema.toMap private val typeOf = db.typechecker.typeOf(_:Expression, scope = typeOfInputColumn) - val tupleSchema: Seq[(String,Type)] = + val tupleSchema: Seq[(String,BaseType)] = groupByColumns.map { col => (col.name, typeOfInputColumn(col.name)) } ++ aggregates.map { fn => ( fn.alias, - db.aggregates.typecheck(fn.function, fn.args.map { 
typeOf(_) }) + db.types.rootType( + db.aggregates.typecheck( + fn.function, + fn.args.map { typeOf(_) } + .map { db.types.rootType(_) } + ) + ) ) } - val annotationSchema: Seq[(String,Type)] = Seq() + val annotationSchema: Seq[(String,BaseType)] = Seq() val aggNames = aggregates.map { _.alias } val aggTypes = aggregates.map { agg => (agg, agg.args.map { typeOf(_) }) } - private val aggEvalScope: Map[String,(Type, Row => PrimitiveValue)] = + private val aggEvalScope: Map[String,(BaseType, Row => PrimitiveValue)] = inputSchema.zipWithIndex.map { case ((name, t), idx) => logger.debug(s"For $name ($t) using idx = $idx") diff --git a/src/main/scala/mimir/exec/result/JDBCResultIterator.scala b/src/main/scala/mimir/exec/result/JDBCResultIterator.scala index 65401e5d..f27afb6a 100644 --- a/src/main/scala/mimir/exec/result/JDBCResultIterator.scala +++ b/src/main/scala/mimir/exec/result/JDBCResultIterator.scala @@ -3,6 +3,7 @@ package mimir.exec.result import java.sql._ import com.typesafe.scalalogging.slf4j.LazyLogging import mimir.algebra._ +import mimir.algebra.typeregistry.TypeRegistry import mimir.util._ import mimir.exec._ import mimir.sql.MetadataBackend @@ -12,6 +13,7 @@ class JDBCResultIterator( inputSchema: Seq[(String,Type)], query: SelectBody, backend: MetadataBackend, + types: TypeRegistry, dateType: (Type) ) extends ResultIterator @@ -26,7 +28,7 @@ class JDBCResultIterator( zipWithIndex. map { case ((name, t), idx) => logger.debug(s"Extracting Raw: $name (@$idx) -> $t") - val fn = JDBCUtils.convertFunction(t, idx+1, dateType = dateType) + val fn = JDBCUtils.convertFunction(types.rootType(t), idx+1, dateType = dateType) () => { fn(source) } } diff --git a/src/main/scala/mimir/exec/result/SparkResultIterator.scala b/src/main/scala/mimir/exec/result/SparkResultIterator.scala index c087e960..388e1f81 100644 --- a/src/main/scala/mimir/exec/result/SparkResultIterator.scala +++ b/src/main/scala/mimir/exec/result/SparkResultIterator.scala @@ -3,6 +3,7 @@ package mimir.exec.result import org.apache.spark.sql.DataFrame import com.typesafe.scalalogging.slf4j.LazyLogging import mimir.algebra._ +import mimir.algebra.typeregistry.TypeRegistry import mimir.sql.RABackend import mimir.util.SparkUtils import mimir.util.Timer @@ -11,6 +12,7 @@ class SparkResultIterator( inputSchema: Seq[(String,Type)], query: Operator, backend: RABackend, + types: TypeRegistry, dateType: (Type) ) extends ResultIterator @@ -24,7 +26,7 @@ class SparkResultIterator( zipWithIndex. 
map { case ((name, t), idx) => logger.debug(s"Extracting Raw: $name (@$idx) -> $t") - val fn = SparkUtils.convertFunction(t, idx, dateType = dateType) + val fn = SparkUtils.convertFunction(types.rootType(t), idx, dateType = dateType) () => { fn(row) } } diff --git a/src/main/scala/mimir/lenses/CommentLens.scala b/src/main/scala/mimir/lenses/CommentLens.scala index 8bf2efdf..ad3cb644 100644 --- a/src/main/scala/mimir/lenses/CommentLens.scala +++ b/src/main/scala/mimir/lenses/CommentLens.scala @@ -45,7 +45,7 @@ object CommentLens { "COMMENT_ARG_"+index ( ProjectArg(outputCol, VGTerm(modelName, index, Seq(RowIdVar()), Seq(expr))), - (outputCol, db.typechecker.typeOf(expr, query)), + (outputCol, db.types.rootType(db.typechecker.typeOf(expr, query))), comment ) } diff --git a/src/main/scala/mimir/lenses/MissingKeyLens.scala b/src/main/scala/mimir/lenses/MissingKeyLens.scala index fde3b09f..f27884d8 100644 --- a/src/main/scala/mimir/lenses/MissingKeyLens.scala +++ b/src/main/scala/mimir/lenses/MissingKeyLens.scala @@ -19,11 +19,11 @@ object MissingKeyLens { args:Seq[Expression] ): (Operator, Seq[Model]) = { - val schema = db.typechecker.schemaOf(query) + val schema = db.typechecker.baseSchemaOf(query) val schemaMap = schema.toMap var missingOnly = false; var sortCols = Seq[(String, Boolean)]() - val keys: Seq[(String, Type)] = args.flatMap { + val keys: Seq[(String, BaseType)] = args.flatMap { case Var(col) => { if(schemaMap contains col){ Some((col, schemaMap(col))) } else { diff --git a/src/main/scala/mimir/lenses/PickerLens.scala b/src/main/scala/mimir/lenses/PickerLens.scala index 51c19d39..760e7ab2 100644 --- a/src/main/scala/mimir/lenses/PickerLens.scala +++ b/src/main/scala/mimir/lenses/PickerLens.scala @@ -31,7 +31,7 @@ object PickerLens { val (pickFromColumns, pickerColTypes ) = args.flatMap { case Function("PICK_FROM", cols ) => - Some( cols.map { case col:Var => (col.name, schemaMap(col.name)) + Some( cols.map { case col:Var => (col.name, db.types.rootType(schemaMap(col.name))) case col => throw new RAException(s"Invalid pick_from argument: $col in PickerLens $name (not a column reference)") } ) case _ => None diff --git a/src/main/scala/mimir/lenses/RepairKeyLens.scala b/src/main/scala/mimir/lenses/RepairKeyLens.scala index e297ef3f..9ee3ed8b 100644 --- a/src/main/scala/mimir/lenses/RepairKeyLens.scala +++ b/src/main/scala/mimir/lenses/RepairKeyLens.scala @@ -56,8 +56,8 @@ object RepairKeyLens extends LazyLogging { s"$name:$col", name, query, - keys.map { k => (k, schemaMap(k)) }, - col, t, + keys.map { k => (k, db.types.rootType(schemaMap(k))) }, + col, db.types.rootType(t), scoreCol ) model.trainDomain(db)//.reconnectToDatabase(db) diff --git a/src/main/scala/mimir/ml/spark/Classification.scala b/src/main/scala/mimir/ml/spark/Classification.scala index b1e1328b..d1567625 100644 --- a/src/main/scala/mimir/ml/spark/Classification.scala +++ b/src/main/scala/mimir/ml/spark/Classification.scala @@ -48,8 +48,8 @@ object Classification extends SparkML { applyModelDB(model, query, db) } - def classify( model : PipelineModel, cols:Seq[(String, Type)], testData : List[Seq[PrimitiveValue]]): DataFrame = { - applyModel(model, cols, testData) + def classify( model : PipelineModel, cols:Seq[(String, BaseType)], testData : List[Seq[PrimitiveValue]], sparkTranslator: OperatorTranslation): DataFrame = { + applyModel(model, cols, testData, sparkTranslator) } override def extractPredictions(model : PipelineModel, predictions:DataFrame, maxPredictions:Int = 5) : Seq[(String, (String, Double))] = { 
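Reviewer note: the signature change above (and the matching one to Regression.regress below) drops the reliance on a global OperatorTranslation object; the Spark translator is now passed in by the caller, and schemas are expressed over BaseType. A minimal sketch of the updated call site, assuming a hypothetical open db: mimir.Database and an already-trained model: PipelineModel (both names are illustrative, not part of the patch):

    import mimir.algebra._
    import mimir.ml.spark.Classification

    // Schemas now use BaseType rather than Type.
    val cols: Seq[(String, BaseType)] = Seq(("ROWID", TString()), ("AGE", TInt()))
    val testData: List[Seq[PrimitiveValue]] =
      List(Seq(StringPrimitive("1"), IntPrimitive(42)))

    // The translator instance is threaded in explicitly, replacing the
    // removed OperatorTranslation singleton lookup.
    val predictions = Classification.classify(model, cols, testData, db.sparkTranslator)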
diff --git a/src/main/scala/mimir/ml/spark/Regression.scala b/src/main/scala/mimir/ml/spark/Regression.scala index 8e384306..6ce97aff 100644 --- a/src/main/scala/mimir/ml/spark/Regression.scala +++ b/src/main/scala/mimir/ml/spark/Regression.scala @@ -26,8 +26,8 @@ object Regression extends SparkML { applyModelDB(model, query, db) } - def regress( model : PipelineModel, cols:Seq[(String, Type)], testData : List[Seq[PrimitiveValue]]): DataFrame = { - applyModel(model, cols, testData) + def regress( model : PipelineModel, cols:Seq[(String, BaseType)], testData : List[Seq[PrimitiveValue]], sparkTranslator: OperatorTranslation): DataFrame = { + applyModel(model, cols, testData, sparkTranslator) } override def extractPredictions(model : PipelineModel, predictions:DataFrame, maxPredictions:Int = 5) : Seq[(String, (String, Double))] = { diff --git a/src/main/scala/mimir/ml/spark/SparkML.scala b/src/main/scala/mimir/ml/spark/SparkML.scala index eb952479..8fb809cf 100644 --- a/src/main/scala/mimir/ml/spark/SparkML.scala +++ b/src/main/scala/mimir/ml/spark/SparkML.scala @@ -32,10 +32,10 @@ object SparkML { sc = Some(spark.sparkSession.sparkContext) sqlCtx = Some(spark) } - def getDataFrameWithProvFromQuery(db:Database, query:Operator) : (Seq[(String, Type)], DataFrame) = { + def getDataFrameWithProvFromQuery(db:Database, query:Operator) : (Seq[(String, BaseType)], DataFrame) = { val prov = if(ExperimentalOptions.isEnabled("GPROM-PROVENANCE") && ExperimentalOptions.isEnabled("GPROM-BACKEND")) - { Provenance.compileGProM(query) } + { db.gpromTranslator.compileProvenanceWithGProM(query) } else { Provenance.compile(query) } val oper = prov._1 val provenanceCols = prov._2 @@ -46,7 +46,7 @@ object SparkML { val dfOut = dfPreOut.schema.fields.filter(col => Seq(DateType, TimestampType).contains(col.dataType)).foldLeft(dfPreOut)((init, cur) => init.withColumn(cur.name,init(cur.name).cast(LongType)) ) (db.typechecker.schemaOf(operWProv).map(el => el._2 match { case TDate() | TTimestamp() => (el._1, TInt()) - case _ => el + case _ => (el._1, db.types.rootType(el._2)) }), dfOut) } } @@ -106,16 +106,22 @@ abstract class SparkML { val data = db.query(query)(results => { results.toList.map(row => row.provenance +: row.tupleSchema.zip(row.tuple).filterNot(_._1._1.equalsIgnoreCase("rowid")).unzip._2) }) - applyModel(model, ("rowid", TString()) +:db.typechecker.schemaOf(query).filterNot(_._1.equalsIgnoreCase("rowid")), data, dfTransformer) + applyModel( + model, + ("rowid", TString()) +:db.typechecker.baseSchemaOf(query).filterNot(_._1.equalsIgnoreCase("rowid")), + data, + db.sparkTranslator, + dfTransformer + ) } - def applyModel( model : PipelineModel, cols:Seq[(String, Type)], testData : List[Seq[PrimitiveValue]], dfTransformer:Option[DataFrameTransformer] = None): DataFrame = { + def applyModel( model : PipelineModel, cols:Seq[(String, BaseType)], testData : List[Seq[PrimitiveValue]], sparkTranslator: OperatorTranslation, dfTransformer:Option[DataFrameTransformer] = None): DataFrame = { val sqlContext = getSparkSqlContext() import sqlContext.implicits._ val modDF = dfTransformer.getOrElse((df:DataFrame) => df) model.transform(modDF(sqlContext.createDataFrame( getSparkSession().parallelize(testData.map( row => { - Row(row.zip(cols).map(value => OperatorTranslation.mimirExprToSparkExpr(null, value._1)):_*) + Row(row.zip(cols).map(value => sparkTranslator.mimirExprToSparkExpr(null, value._1)):_*) })), StructType(cols.toList.map(col => StructField(col._1, OperatorTranslation.getSparkType(col._2), true)))))) } @@ 
-127,7 +133,7 @@ abstract class SparkML { def extractPredictionsForRow(model : PipelineModel, predictions:DataFrame, rowid:String, maxPredictions:Int = 5) : Seq[(String, Double)] - def getNative(value:PrimitiveValue, t:Type): Any = { + def getNative(value:PrimitiveValue, t:BaseType): Any = { value match { case NullPrimitive() => t match { case TInt() => 0L @@ -136,12 +142,10 @@ abstract class SparkML { case TString() => "" case TBool() => new java.lang.Boolean(false) case TRowId() => "" - case TType() => "" + case TType() => "any" case TAny() => "" case TTimestamp() => OperatorTranslation.defaultTimestamp - case TInterval() => "" - case TUser(name) => getNative(value, mimir.algebra.TypeRegistry.registeredTypes(name)._2) - case x => "" + case TInterval() => OperatorTranslation.defaultInterval } case RowIdPrimitive(s) => s case StringPrimitive(s) => s @@ -150,7 +154,8 @@ abstract class SparkML { case BoolPrimitive(b) => new java.lang.Boolean(b) case ts@TimestampPrimitive(y,m,d,h,mm,s,ms) => SparkUtils.convertTimestamp(ts) case dt@DatePrimitive(y,m,d) => SparkUtils.convertDate(dt) - case x => x.asString + case IntervalPrimitive(period) => period + case TypePrimitive(t) => t } } } diff --git a/src/main/scala/mimir/models/BasicModels.scala b/src/main/scala/mimir/models/BasicModels.scala index ed7c0616..df6a9075 100644 --- a/src/main/scala/mimir/models/BasicModels.scala +++ b/src/main/scala/mimir/models/BasicModels.scala @@ -7,7 +7,7 @@ import scala.util._ object UniformDistribution extends Model("UNIFORM") with Serializable { def argTypes(idx: Int) = List(TFloat(), TFloat()) def hintTypes(idx: Int) = Seq() - def varType(idx: Int, argTypes: Seq[Type]) = TFloat() + def varType(idx: Int, argTypes: Seq[BaseType]) = TFloat() def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]) = FloatPrimitive((args(0).asDouble + args(1).asDouble) / 2.0) def sample(idx: Int, randomness: Random, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]) = { @@ -53,7 +53,7 @@ case class NoOpModel(override val name: String, reasonText:String) def argTypes(idx: Int) = List(TAny()) def hintTypes(idx: Int) = Seq() - def varType(idx: Int, args: Seq[Type]) = args(0) + def varType(idx: Int, args: Seq[BaseType]) = args(0) def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]) = args(0) def sample(idx: Int, randomness: Random, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]) = args(0) def reason(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): String = reasonText diff --git a/src/main/scala/mimir/models/CommentModel.scala b/src/main/scala/mimir/models/CommentModel.scala index 7d43f462..2b65162d 100644 --- a/src/main/scala/mimir/models/CommentModel.scala +++ b/src/main/scala/mimir/models/CommentModel.scala @@ -14,7 +14,7 @@ import java.sql.SQLException * The return value is an integer identifying the ordinal position of the selected value, starting with 0. 
*/ @SerialVersionUID(1001L) -class CommentModel(override val name: String, cols:Seq[String], colTypes:Seq[Type], comments:Seq[String]) +class CommentModel(override val name: String, cols:Seq[String], colTypes:Seq[BaseType], comments:Seq[String]) extends Model(name) with Serializable with SourcedFeedback @@ -23,7 +23,7 @@ class CommentModel(override val name: String, cols:Seq[String], colTypes:Seq[Typ def getFeedbackKey(idx: Int, args: Seq[PrimitiveValue] ) : String = s"${args(0).asString}:$idx" def argTypes(idx: Int) = Seq(TRowId()) - def varType(idx: Int, args: Seq[Type]) = colTypes(idx) + def varType(idx: Int, args: Seq[BaseType]) = colTypes(idx) def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) = { getFeedback(idx, args) match { case Some(v) => v @@ -48,7 +48,7 @@ class CommentModel(override val name: String, cols:Seq[String], colTypes:Seq[Typ setFeedback(idx, args, v) } def isAcknowledged (idx: Int, args: Seq[PrimitiveValue]): Boolean = hasFeedback(idx, args) - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = colTypes + def hintTypes(idx: Int): Seq[BaseType] = colTypes //def getDomain(idx: Int, args: Seq[PrimitiveValue], hints:Seq[PrimitiveValue]): Seq[(PrimitiveValue,Double)] = Seq((hints(0), 0.0)) def confidence (idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): Double = { diff --git a/src/main/scala/mimir/models/DefaultMetaModel.scala b/src/main/scala/mimir/models/DefaultMetaModel.scala index aa9eb740..514cc65b 100644 --- a/src/main/scala/mimir/models/DefaultMetaModel.scala +++ b/src/main/scala/mimir/models/DefaultMetaModel.scala @@ -18,7 +18,7 @@ class DefaultMetaModel(name: String, context: String, models: Seq[String]) with NoArgModel with FiniteDiscreteDomain { - def varType(idx: Int, args: Seq[Type]): Type = TString() + def varType(idx: Int, args: Seq[BaseType]): BaseType = TString() def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = choices(idx).getOrElse( StringPrimitive(models.head)) def sample(idx: Int, randomness: Random, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = diff --git a/src/main/scala/mimir/models/DetectHeaderModel.scala b/src/main/scala/mimir/models/DetectHeaderModel.scala index c45fd396..62fe676c 100644 --- a/src/main/scala/mimir/models/DetectHeaderModel.scala +++ b/src/main/scala/mimir/models/DetectHeaderModel.scala @@ -52,6 +52,9 @@ with SourcedFeedback } else { + // TODO: COMMENT THIS CODE + // From Oliver 12/16/2018: + // This code is really hard to follow. It needs some serious commenting love. 
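+      // A first pass at that, describing the shape of what follows: take the
+      // sampled rows (top6), treat the first row as a candidate header by
+      // sanitizing each of its values into a column name, then infer a type
+      // for every remaining cell by attempting a Cast to each entry in
+      // BaseType.tests, keeping the last cast that succeeds (NULL cells
+      // vote for TAny()).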
val top6 = trainingData val (header, topRecords) = (top6.head.map(col => sanitizeColumnName(col match { case NullPrimitive() => "NULL" @@ -62,10 +65,10 @@ with SourcedFeedback (pv._1 match { case NullPrimitive() => TAny() case x => { - Type.rootTypes.foldLeft(TAny():Type)((tinit, ttype) => { - Cast.apply(ttype,x) match { + BaseType.tests.foldLeft(TAny():Type)((tinit, ttype) => { + Cast(ttype._1,x) match { case NullPrimitive() => tinit - case x => ttype + case x => ttype._1 } }) } @@ -118,7 +121,7 @@ with SourcedFeedback def argTypes(idx: Int) = { Seq(TInt()) } - def varType(idx: Int, args: Seq[Type]) = { + def varType(idx: Int, args: Seq[BaseType]) = { TString() } def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) = { @@ -143,7 +146,7 @@ with SourcedFeedback def isAcknowledged (idx: Int, args: Seq[PrimitiveValue]): Boolean = { hasFeedback(idx, args) } - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = { + def hintTypes(idx: Int): Seq[BaseType] = { Seq() } def getFeedbackKey(idx: Int, args: Seq[PrimitiveValue]) = { diff --git a/src/main/scala/mimir/models/EditDistanceMatchModel.scala b/src/main/scala/mimir/models/EditDistanceMatchModel.scala index 70247f3a..e73894e1 100644 --- a/src/main/scala/mimir/models/EditDistanceMatchModel.scala +++ b/src/main/scala/mimir/models/EditDistanceMatchModel.scala @@ -30,15 +30,15 @@ object EditDistanceMatchModel def train( db: Database, name: String, - source: Either[Operator,Seq[(String,Type)]], - target: Either[Operator,Seq[(String,Type)]] + source: Either[Operator,Seq[(String,BaseType)]], + target: Either[Operator,Seq[(String,BaseType)]] ): Map[String,(Model,Int)] = { - val sourceSch: Seq[(String,Type)] = source match { - case Left(oper) => db.typechecker.schemaOf(oper) + val sourceSch: Seq[(String,BaseType)] = source match { + case Left(oper) => db.typechecker.baseSchemaOf(oper) case Right(sch) => sch } - val targetSch: Seq[(String,Type)] = target match { - case Left(oper) => db.typechecker.schemaOf(oper) + val targetSch: Seq[(String,BaseType)] = target match { + case Left(oper) => db.typechecker.baseSchemaOf(oper) case Right(sch) => sch } targetSch.map({ case (targetCol,targetType) => @@ -47,22 +47,20 @@ object EditDistanceMatchModel s"$name:$targetCol", defaultMetric, (targetCol, targetType), - sourceSch. - filter((x) => isTypeCompatible(targetType, x._2)). - map( _._1 ) + sourceSch + .filter((x) => isTypeCompatible(db.types.rootType(targetType), db.types.rootType(x._2))) + .map( _._1 ) ), 0)) }).toMap } - def isTypeCompatible(a: Type, b: Type): Boolean = + def isTypeCompatible(a: BaseType, b: BaseType): Boolean = { - val aBase = Type.rootType(a) - val bBase = Type.rootType(b) - (aBase, bBase) match { + (a, b) match { case ((TInt()|TFloat()), (TInt()|TFloat())) => true case (TAny(), _) => true case (_, TAny()) => true - case _ => aBase == bBase + case _ => a == b } } @@ -72,7 +70,7 @@ object EditDistanceMatchModel class EditDistanceMatchModel( name: String, metricName: String, - target: (String, Type), + target: (String, BaseType), sourceCandidates: Seq[String] ) extends Model(name) @@ -101,7 +99,7 @@ class EditDistanceMatchModel( map({ case (k, v) => (k, v.toDouble) }). 
toIndexedSeq } - def varType(idx: Int, t: Seq[Type]) = TString() + def varType(idx: Int, t: Seq[BaseType]) = TString() def sample(idx: Int, randomness: Random, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = { diff --git a/src/main/scala/mimir/models/FuncDepModel.scala b/src/main/scala/mimir/models/FuncDepModel.scala index 908f0e3f..4d302fdb 100644 --- a/src/main/scala/mimir/models/FuncDepModel.scala +++ b/src/main/scala/mimir/models/FuncDepModel.scala @@ -34,7 +34,13 @@ object FuncDepModel db.models.getOption(modelName) match { case Some(model) => model case None => { - val model = new SimpleFuncDepModel(modelName, QueryNamer(query), col, db.typechecker.schemaOf(query)) + val model = new SimpleFuncDepModel( + modelName, + QueryNamer(query), + col, + db.typechecker.schemaOf(query) + .map { case (col, t) => (col, db.types.rootType(t)) } + ) sourceCol = trainModel(db, query, model) db.models.persist(model) model @@ -81,7 +87,7 @@ object FuncDepModel } @SerialVersionUID(1001L) -class SimpleFuncDepModel(name: String, val tableName:String, val colName: String, val schema:Seq[(String,Type)]) +class SimpleFuncDepModel(name: String, val tableName:String, val colName: String, val schema:Seq[(String,BaseType)]) extends Model(name) { val colIdx:Int = schema.map{_._1}.indexOf(colName) @@ -105,13 +111,13 @@ class SimpleFuncDepModel(name: String, val tableName:String, val colName: String def isAcknowledged(idx: Int, args: Seq[PrimitiveValue]): Boolean = feedback contains(args(0).asString) - def guessInputType: Type = + def guessInputType: BaseType = schema(colIdx)._2 - def argTypes(idx: Int): Seq[Type] = List(TRowId()) + def argTypes(idx: Int): Seq[BaseType] = List(TRowId()) def hintTypes(idx: Int) = schema.map(_._2) - def varType(idx: Int, args: Seq[Type]): Type = guessInputType + def varType(idx: Int, args: Seq[BaseType]): BaseType = guessInputType def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = { diff --git a/src/main/scala/mimir/models/GeocodingModel.scala b/src/main/scala/mimir/models/GeocodingModel.scala index a17b23c4..467ad7ec 100644 --- a/src/main/scala/mimir/models/GeocodingModel.scala +++ b/src/main/scala/mimir/models/GeocodingModel.scala @@ -43,7 +43,7 @@ class GeocodingModel(override val name: String, addrCols:Seq[Expression], geocod def argTypes(idx: Int) = { Seq(TRowId()).union(addrCols.map(_ => TString())) } - def varType(idx: Int, args: Seq[Type]) = TFloat() + def varType(idx: Int, args: Seq[BaseType]) = TFloat() def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) = { val rowid = RowIdPrimitive(args(0).asString) getFeedback(idx, args) match { @@ -91,7 +91,7 @@ class GeocodingModel(override val name: String, addrCols:Seq[Expression], geocod def isAcknowledged (idx: Int, args: Seq[PrimitiveValue]): Boolean = { hasFeedback(idx, args) } - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = Seq() + def hintTypes(idx: Int): Seq[BaseType] = Seq() def getDomain(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): Seq[(PrimitiveValue,Double)] = { diff --git a/src/main/scala/mimir/models/MissingKeyModel.scala b/src/main/scala/mimir/models/MissingKeyModel.scala index fab5ac74..2aa5a221 100644 --- a/src/main/scala/mimir/models/MissingKeyModel.scala +++ b/src/main/scala/mimir/models/MissingKeyModel.scala @@ -13,7 +13,7 @@ import mimir.util._ * The return value is an integer identifying the ordinal position of the selected value, starting with 0. 
*/ @SerialVersionUID(1001L) -class MissingKeyModel(override val name: String, keys:Seq[String], colTypes:Seq[Type]) +class MissingKeyModel(override val name: String, keys:Seq[String], colTypes:Seq[BaseType]) extends Model(name) with Serializable with FiniteDiscreteDomain @@ -25,7 +25,7 @@ class MissingKeyModel(override val name: String, keys:Seq[String], colTypes:Seq[ def argTypes(idx: Int) = { Seq(TRowId()) } - def varType(idx: Int, args: Seq[Type]) = colTypes(idx) + def varType(idx: Int, args: Seq[BaseType]) = colTypes(idx) def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) = { //println(s"MissingKeyModel:bestGuess: idx: $idx args: ${args.mkString("[ ",","," ]")} hints: ${hints.mkString("[ ",","," ]")}") getFeedback(idx, args) match { @@ -63,7 +63,7 @@ class MissingKeyModel(override val name: String, keys:Seq[String], colTypes:Seq[ def isAcknowledged (idx: Int, args: Seq[PrimitiveValue]): Boolean = { hasFeedback(idx, args) } - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = Seq(TAny()) + def hintTypes(idx: Int): Seq[BaseType] = Seq(TAny()) def getDomain(idx: Int, args: Seq[PrimitiveValue], hints:Seq[PrimitiveValue]): Seq[(PrimitiveValue,Double)] = Seq((hints(0), 0.0)) def confidence (idx: Int, args: Seq[PrimitiveValue], hints:Seq[PrimitiveValue]) : Double = { diff --git a/src/main/scala/mimir/models/Model.scala b/src/main/scala/mimir/models/Model.scala index 3fbac3bf..77aad890 100644 --- a/src/main/scala/mimir/models/Model.scala +++ b/src/main/scala/mimir/models/Model.scala @@ -51,18 +51,18 @@ abstract class Model(val name: String) extends Serializable { /** * The list of expected arg types (may be TAny) */ - def argTypes (idx: Int): Seq[Type] + def argTypes (idx: Int): Seq[BaseType] /** * The list of expected hint types (may be TAny) */ - def hintTypes (idx: Int): Seq[Type] + def hintTypes (idx: Int): Seq[BaseType] /** * Infer the type of the model from the types of the inputs * @param argTypes The types of the arguments the the VGTerm * @return The type of the value returned by this model */ - def varType (idx: Int, argTypes: Seq[Type]): Type + def varType (idx: Int, argTypes: Seq[BaseType]): BaseType /** * Generate a best guess for a variable represented by this model. 
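
The Model.scala hunk above is the heart of the migration: every type-facing member of the Model interface now traffics in BaseType, so user-defined types are resolved to their root type (see the db.types.rootType calls at the lens call sites earlier in this patch) before they ever reach a model. Distilled to its signatures, as a sketch for orientation rather than code from the patch:

    import mimir.algebra._

    // The migrated surface of mimir.models.Model: members that used to accept
    // or return Type now use BaseType, so models never see a TUser wrapper.
    trait TypedModelSurface {
      def argTypes(idx: Int): Seq[BaseType]                     // was Seq[Type]
      def hintTypes(idx: Int): Seq[BaseType]                    // was Seq[Type]
      def varType(idx: Int, argTypes: Seq[BaseType]): BaseType  // was Seq[Type] => Type
    }
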
diff --git a/src/main/scala/mimir/models/ModelBuildingBlocks.scala b/src/main/scala/mimir/models/ModelBuildingBlocks.scala index aba7da36..977629c1 100644 --- a/src/main/scala/mimir/models/ModelBuildingBlocks.scala +++ b/src/main/scala/mimir/models/ModelBuildingBlocks.scala @@ -4,7 +4,7 @@ import mimir.algebra._ trait NoArgModel { - def argTypes(x: Int): Seq[Type] = List() + def argTypes(x: Int): Seq[BaseType] = List() def hintTypes(idx: Int) = Seq() } diff --git a/src/main/scala/mimir/models/ModelRegistry.scala b/src/main/scala/mimir/models/ModelRegistry.scala index b54efc24..7b58a867 100644 --- a/src/main/scala/mimir/models/ModelRegistry.scala +++ b/src/main/scala/mimir/models/ModelRegistry.scala @@ -85,8 +85,8 @@ object ModelRegistry type SchemaMatchConstructor = ((Database, String, - Either[Operator,Seq[(String,Type)]], - Either[Operator,Seq[(String,Type)]]) => + Either[Operator,Seq[(String,BaseType)]], + Either[Operator,Seq[(String,BaseType)]]) => Map[String,(Model,Int)]) /////////////////// PREDEFINED CONSTRUCTORS /////////////////// diff --git a/src/main/scala/mimir/models/PickerModel.scala b/src/main/scala/mimir/models/PickerModel.scala index 668f5a05..e201c0fc 100644 --- a/src/main/scala/mimir/models/PickerModel.scala +++ b/src/main/scala/mimir/models/PickerModel.scala @@ -26,7 +26,7 @@ object PickerModel def availableSparkModels(trainingDataq:DataFrame) = Map("Classification" -> (Classification, Classification.DecisionTreeMulticlassModel(trainingDataq)), "Regression" -> (Regression, Regression.GeneralizedLinearRegressorModel(trainingDataq))) - def train(db:Database, name: String, resultColumn:String, pickFromCols:Seq[String], colTypes:Seq[Type], useClassifier:Option[String], classifyUpFrontAndCache:Boolean, query: Operator ) : SimplePickerModel = { + def train(db:Database, name: String, resultColumn:String, pickFromCols:Seq[String], colTypes:Seq[BaseType], useClassifier:Option[String], classifyUpFrontAndCache:Boolean, query: Operator ) : SimplePickerModel = { val pickerModel = new SimplePickerModel(name, resultColumn, pickFromCols, colTypes, useClassifier, classifyUpFrontAndCache, query) val trainingQuery = Limit(0, Some(TRAINING_LIMIT), Sort(Seq(SortColumn(Function("random", Seq()), true)), Project(pickFromCols.map(col => ProjectArg(col, Var(col))), query.filter(Not(IsNullExpression(Var(pickFromCols.head)))) ))) val (schemao, trainingDatao) = SparkML.getDataFrameWithProvFromQuery(db, trainingQuery) @@ -51,7 +51,7 @@ object PickerModel * The return value is an integer identifying the ordinal position of the selected value, starting with 0. 
*/ @SerialVersionUID(1002L) -class SimplePickerModel(override val name: String, resultColumn:String, pickFromCols:Seq[String], colTypes:Seq[Type], useClassifier:Option[String], classifyUpFrontAndCache:Boolean, source: Operator) +class SimplePickerModel(override val name: String, resultColumn:String, pickFromCols:Seq[String], colTypes:Seq[BaseType], useClassifier:Option[String], classifyUpFrontAndCache:Boolean, source: Operator) extends Model(name) with Serializable with FiniteDiscreteDomain @@ -61,7 +61,7 @@ class SimplePickerModel(override val name: String, resultColumn:String, pickFrom var classifierModel: Option[PipelineModel] = None var classifyAllPredictions:Option[Map[String, Seq[(String, Double)]]] = None - var schema:Seq[(String, Type)] = null + var schema:Seq[(String, BaseType)] = null var trainingData:DataFrame = null @@ -114,9 +114,9 @@ class SimplePickerModel(override val name: String, resultColumn:String, pickFrom } def argTypes(idx: Int) = { - Seq(TRowId()).union(colTypes) + Seq(TRowId()) ++ colTypes } - def varType(idx: Int, args: Seq[Type]) = colTypes(idx) + def varType(idx: Int, args: Seq[BaseType]) = colTypes(idx) def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) = { val rowid = args(0).asString getFeedback(idx, args) match { @@ -193,7 +193,7 @@ class SimplePickerModel(override val name: String, resultColumn:String, pickFrom def isAcknowledged (idx: Int, args: Seq[PrimitiveValue]): Boolean = { hasFeedback(idx, args) } - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = useClassifier match { + def hintTypes(idx: Int): Seq[BaseType] = useClassifier match { case None => Seq(TAny()) case Some(x) => Seq() } diff --git a/src/main/scala/mimir/models/RepairKeyModel.scala b/src/main/scala/mimir/models/RepairKeyModel.scala index bb8ddde4..aabcda35 100644 --- a/src/main/scala/mimir/models/RepairKeyModel.scala +++ b/src/main/scala/mimir/models/RepairKeyModel.scala @@ -23,9 +23,9 @@ class RepairKeyModel( name: String, context: String, source: Operator, - keys: Seq[(String, Type)], + keys: Seq[(String, BaseType)], target: String, - targetType: Type, + targetType: BaseType, scoreCol: Option[String] ) extends Model(name) @@ -37,7 +37,7 @@ class RepairKeyModel( def getFeedbackKey(idx: Int, args: Seq[PrimitiveValue]) : List[PrimitiveValue] = args.toList - def varType(idx: Int, args: Seq[Type]): Type = targetType + def varType(idx: Int, args: Seq[BaseType]): BaseType = targetType def argTypes(idx: Int) = keys.map(_._2) def hintTypes(idx: Int) = Seq(TString(), TString()) @@ -78,7 +78,7 @@ class RepairKeyModel( source ) ) - val mopSchema = db.typechecker.schemaOf(mop) + val mopSchema = db.typechecker.baseSchemaOf(mop) domainCache = db.backend.execute(mop).map( row => { ( SparkUtils.convertField(mopSchema(0)._2, row, 0, TString()), scoreCol match { @@ -139,7 +139,7 @@ class RepairKeyModel( case TString() => StringPrimitive(value.asInstanceOf[String]) case TBool() => BoolPrimitive(value.asInstanceOf[Boolean]) case TRowId() => RowIdPrimitive(value.asInstanceOf[String]) - case TType() => TypePrimitive(Type.fromString(value.asInstanceOf[String])) + // case TType() => TypePrimitive(BaseType.fromString(value.asInstanceOf[String])) //case TAny() => NullPrimitive() //case TUser(name) => name.toLowerCase //case TInterval() => Primitive(value.asInstanceOf[Long]) diff --git a/src/main/scala/mimir/models/SeriesMissingValueModel.scala b/src/main/scala/mimir/models/SeriesMissingValueModel.scala index c32e39ba..d38dbd7e 100644 --- 
a/src/main/scala/mimir/models/SeriesMissingValueModel.scala +++ b/src/main/scala/mimir/models/SeriesMissingValueModel.scala @@ -68,11 +68,11 @@ class SimpleSeriesModel(name: String, colNames: Seq[String], query: Operator) var predictions:Seq[Dataset[(String, Double)]] = Seq() var queryDf: DataFrame = null - var rowIdType:Type = TString() - var dateType:Type = TDate() + var rowIdType:BaseType = TString() + var dateType:BaseType = TDate() //colNames.map { _ => Seq[(String,Double)]() } var queryCols:Seq[String] = colNames - var querySchema:Seq[(String,Type)] = + var querySchema:Seq[(String,BaseType)] = colNames.map { x => (x, TAny()) } def getCacheKey(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ) : String = s"${idx}_${args(0).asString}_${hints(0).asString}" @@ -80,7 +80,7 @@ class SimpleSeriesModel(name: String, colNames: Seq[String], query: Operator) def train(db: Database): Seq[Boolean] = { - querySchema = db.typechecker.schemaOf(query) + querySchema = db.typechecker.baseSchemaOf(query) queryCols = querySchema.unzip._1 queryDf = db.backend.execute(query) rowIdType = db.backend.rowIdType @@ -158,7 +158,7 @@ class SimpleSeriesModel(name: String, colNames: Seq[String], query: Operator) def argTypes(idx: Int) = Seq(TRowId()) - def varType(idx: Int, args: Seq[Type]): Type = querySchema(idx)._2 + def varType(idx: Int, args: Seq[BaseType]): BaseType = querySchema(idx)._2 def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue] ): PrimitiveValue = { @@ -224,7 +224,7 @@ class SimpleSeriesModel(name: String, colNames: Seq[String], query: Operator) def isAcknowledged(idx: Int, args: Seq[PrimitiveValue]): Boolean = hasFeedback(idx, args) - def hintTypes(idx: Int): Seq[mimir.algebra.Type] = Seq() + def hintTypes(idx: Int): Seq[BaseType] = Seq() def confidence (idx: Int, args: Seq[PrimitiveValue], hints:Seq[PrimitiveValue]) : Double = { val df = predictions(idx) diff --git a/src/main/scala/mimir/models/SparkClassifierModel.scala b/src/main/scala/mimir/models/SparkClassifierModel.scala index ce75d3a7..13e7400e 100644 --- a/src/main/scala/mimir/models/SparkClassifierModel.scala +++ b/src/main/scala/mimir/models/SparkClassifierModel.scala @@ -84,7 +84,7 @@ object SparkClassifierModel } @SerialVersionUID(1001L) -class SimpleSparkClassifierModel(name: String, val colName:String, val schema:Seq[(String, Type)], val data:DataFrame) +class SimpleSparkClassifierModel(name: String, val colName:String, val schema:Seq[(String, BaseType)], val data:DataFrame) extends Model(name) with SourcedFeedback with ModelCache @@ -102,11 +102,10 @@ class SimpleSparkClassifierModel(name: String, val colName:String, val schema:Se - def guessSparkModelType(t:Type) : String = { + def guessSparkModelType(t:BaseType) : String = { t match { case TFloat() if columns.length == 1 => "Regression" case TInt() | TDate() | TString() | TBool() | TRowId() | TType() | TAny() | TTimestamp() | TInterval() => "Classification" - case TUser(name) => guessSparkModelType(mimir.algebra.TypeRegistry.registeredTypes(name)._2) case x => "Classification" } } @@ -153,13 +152,13 @@ class SimpleSparkClassifierModel(name: String, val colName:String, val schema:Se } } - def guessInputType: Type = + def guessInputType: BaseType = schema(colIdx)._2 - def argTypes(idx: Int): Seq[Type] = List(TRowId()) + def argTypes(idx: Int): Seq[BaseType] = List(TRowId()) def hintTypes(idx: Int) = schema.reverse.tail.reverse.map(_._2) - def varType(idx: Int, args: Seq[Type]): Type = guessInputType + def varType(idx: Int, args: 
Seq[BaseType]): BaseType = guessInputType
 
   def bestGuess(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = 
   {
diff --git a/src/main/scala/mimir/models/TypeInferenceModel.scala b/src/main/scala/mimir/models/TypeInferenceModel.scala
index 91c1e91a..f2c6e168 100644
--- a/src/main/scala/mimir/models/TypeInferenceModel.scala
+++ b/src/main/scala/mimir/models/TypeInferenceModel.scala
@@ -5,6 +5,7 @@ import com.typesafe.scalalogging.slf4j.Logger
 
 import mimir.Database
 import mimir.algebra._
+import mimir.algebra.typeregistry.TypeRegistry
 import mimir.util._
 import org.apache.spark.sql.{DataFrame, Row, Encoders, Encoder, Dataset}
 import mimir.ml.spark.SparkML
@@ -35,37 +36,30 @@ object TypeInferenceModel
     case TAny() => -10
   }
 
-  def detectType(v: String): Iterable[Type] = {
-    Type.tests.flatMap({ case (t, regexp) =>
-      regexp.findFirstMatchIn(v).map(_ => t)
-    })++
-    TypeRegistry.matchers.flatMap({ case (regexp, name) =>
-      regexp.findFirstMatchIn(v).map(_ => TUser(name))
-    })
-  }
 }
 
 case class TIVotes(votes:Seq[Map[Int,Double]])
 
-case class VoteList()
+case class VoteList(types:(TypeRegistry with Serializable))
  extends Aggregator[Row, Seq[(Long,Seq[(Int,Long)])], Seq[(Long,Map[Int,(Long,Double)])]] with Serializable {
   def zero = Seq[(Long,Seq[(Int, Long)])]()
   def reduce(acc: Seq[(Long, Seq[(Int, Long)])], x: Row) = {
-    val newacc = x.toSeq.zipWithIndex.map(field =>
+    val newacc:Seq[(Long, Seq[(Int, Long)])] =
+      x.toSeq.zipWithIndex.map(field =>
       field match {
         case (null, idx) => (0L, Seq[(Int, Long)]())
         case (_, idx) => {
           if(!x.isNullAt(idx)){
             val cellVal = x.getString(idx)
-            (1L, TypeInferenceModel.detectType(cellVal).toSeq.map(el => (Type.id(el), 1L)))
+            (1L, types.testForTypes(cellVal).toSeq.map(el => (types.idForType(el):Int, 1L)))
          }
           else (0L, Seq[(Int, Long)]())
         }
       }
     )
     acc match {
-      case Seq() | Seq((0,Seq())) => newacc
+      case Seq() | Seq( (0, Seq()) ) => newacc
       case _ => {
         acc.zip(newacc).map(oldNew => {
           (oldNew._1._1+oldNew._2._1, oldNew._1._2++oldNew._2._2)
@@ -96,21 +90,26 @@ case class VoteList()
 
 
 @SerialVersionUID(1002L)
-class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultFrac: Double, sparkSql:SQLContext, query:Option[DataFrame] )
+class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultFrac: Double, sparkSql:SQLContext, query:Option[DataFrame], types:(TypeRegistry with Serializable))
   extends Model(name)
   with SourcedFeedback
   with FiniteDiscreteDomain
 {
-  var trainingData:Seq[(Long, Map[Int,(Long,Double)])] = query match {
-    case Some(df) => train(df)
-    case None => columns.map(col => (0L,Map[Int,(Long,Double)]()))
-  }
+  // Sequence of
+  //    Total # of Votes
+  //    Map from TypeID -> (# of Votes for Type, % of Votes for Type)
+  // One sequence element per input column
+  var trainingData:Seq[(Long, Map[Int,(Long,Double)])] =
+    query match {
+      case Some(df) => train(df)
+      case None => columns.map(col => (0L,Map[Int,(Long,Double)]()))
+    }
 
   private def train(df:DataFrame) =
   {
     import sparkSql.implicits._
     df.limit(TypeInferenceModel.sampleLimit).select(columns.map(col(_)):_*)
-      .agg(new VoteList().toColumn)
+      .agg(new VoteList(types).toColumn)
       .head()
       .asInstanceOf[Row].toSeq(0).asInstanceOf[Seq[Row]]
       .map(el => (el.getLong(0), el.getMap[Int,Row](1).toMap) )
@@ -119,28 +118,53 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF
 
   final def learn(idx: Int, v: String):Unit =
   {
-    val newtypes = TypeInferenceModel.detectType(v).toSeq.map(tp => (Type.id(tp), 1L))
+    val newtypes = 
types.testForTypes(v).toSeq.map(tp => (types.idForType(tp):Int, 1L)) val oldAcc = trainingData(idx) val (oldTotal, oldTypes) = (oldAcc._1, oldAcc._2.toSeq.map(el => (el._1, el._2._1))) val newTotalVotes = (1+ oldTotal).toLong - trainingData = trainingData.zipWithIndex.map( votesidx => if(votesidx._2 == idx) (newTotalVotes, (newtypes ++ oldTypes).groupBy(_._1).mapValues(el => { - val votesForType = el.map(_._2).sum.toLong - (votesForType, votesForType.toDouble/newTotalVotes.toDouble) - })) else votesidx._1) + trainingData = + trainingData.zipWithIndex.map { votesidx => + if(votesidx._2 == idx) { + ( newTotalVotes, + (newtypes ++ oldTypes).groupBy { _._1:Int } + .mapValues(el => { + val votesForType = el.map(_._2).sum.toLong + (votesForType, votesForType.toDouble/newTotalVotes.toDouble) + }) + .toMap[Int,(Long,Double)] + ) + } else { votesidx._1 } + } } - def voteList(idx:Int) = (Type.id(TString()) -> ((defaultFrac * totalVotes(idx)).toLong, defaultFrac)) :: (trainingData(idx)._2.map(votedType => (votedType._1 -> (votedType._2._1, votedType._2._2)))).toList + def voteList(idx:Int): Seq[(Integer, (Long,Double))] = + Seq( + ( (types.idForType(TString())) -> + ((defaultFrac * totalVotes(idx)).toLong, defaultFrac) + ) + ) ++ ( + trainingData(idx)._2.map { votedType => + ((votedType._1:Integer) -> (votedType._2._1, votedType._2._2)) + }.toSeq + ) def totalVotes(idx:Int) = trainingData(idx)._1 private final def rankFn(x:(Type, Double)) = (x._2, TypeInferenceModel.priority(x._1) ) - def varType(idx: Int, argTypes: Seq[Type]) = TType() + def varType(idx: Int, argTypes: Seq[BaseType]) = TType() def sample(idx: Int, randomness: Random, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): PrimitiveValue = { val column = args(0).asInt TypePrimitive( - Type.toSQLiteType(RandUtils.pickFromWeightedList(randomness, voteList(column).map(el => (el._1, el._2._1.toDouble)).toSeq)) + types.typeForId( + RandUtils.pickFromWeightedList( + randomness, + voteList(column) + .map{ el => (el._1, el._2._1.toDouble) } + .toSeq + ) + ) ) } @@ -149,7 +173,7 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF val column = args(0).asInt getFeedback(idx, args) match { case None => { - val guess = voteList(column).map(tp => (Type.toSQLiteType(tp._1), tp._2._2)).maxBy( rankFn _ )._1 + val guess = voteList(column).map(tp => (types.typeForId(tp._1), tp._2._2)).maxBy( rankFn _ )._1 //println(s"bestGuess(idx: $idx, args: ${args.mkString(",")}, hints:${hints.mkString(",")}) => $guess") TypePrimitive(guess) } @@ -164,10 +188,10 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF val column = args(0).asInt getFeedback(idx, args) match { case None => { - val (guess, guessFrac) = voteList(column).map(tp => (Type.toSQLiteType(tp._1), tp._2._2)).maxBy( rankFn _ ) + val (guess, guessFrac) = voteList(column).map(tp => (types.typeForId(tp._1), tp._2._2)).maxBy( rankFn _ ) val defaultPct = (defaultFrac * 100).toInt val guessPct = (guessFrac*100).toInt - val typeStr = Type.toString(guess).toUpperCase + val typeStr = guess.toString.toUpperCase val reason = guess match { case TString() => @@ -188,7 +212,7 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF def getDomain(idx: Int, args: Seq[PrimitiveValue], hints: Seq[PrimitiveValue]): Seq[(PrimitiveValue,Double)] = { val column = args(0).asInt - trainingData(idx)._2.map( x => (TypePrimitive(Type.toSQLiteType(x._1)), x._2._2)).toSeq ++ Seq( (TypePrimitive(TString()), defaultFrac) ) + 
trainingData(idx)._2.map( x => (TypePrimitive(types.typeForId(x._1)), x._2._2)).toSeq ++ Seq( (TypePrimitive(TString()), defaultFrac) ) } def feedback(idx: Int, args: Seq[PrimitiveValue], v: PrimitiveValue): Unit = @@ -206,9 +230,9 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF voteList(column).map( _._2._1 ).max >= totalVotes(column).toDouble def getFeedbackKey(idx: Int, args: Seq[PrimitiveValue]): String = args(0).asString - def argTypes(idx: Int): Seq[Type] = + def argTypes(idx: Int): Seq[BaseType] = Seq(TInt()) - def hintTypes(idx: Int): Seq[Type] = + def hintTypes(idx: Int): Seq[BaseType] = Seq() @@ -216,10 +240,9 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF val column = args(0).asInt getFeedback(idx, args) match { case None => { - val (guess, guessFrac) = voteList(column).map(tp => (Type.toSQLiteType(tp._1), tp._2._2)).maxBy( rankFn _ ) + val (guess, guessFrac) = voteList(column).map(tp => (types.typeForId(tp._1), tp._2._2)).maxBy( rankFn _ ) val defaultPct = (defaultFrac * 100).toInt val guessPct = (guessFrac*100).toInt - val typeStr = Type.toString(guess).toUpperCase if (guessPct > defaultPct) guessFrac else @@ -230,3 +253,30 @@ class TypeInferenceModel(name: String, val columns: IndexedSeq[String], defaultF } } + +class TypeInferenceCastFailedModel(name: String, t: Type) + extends Model(name) + with SourcedFeedback +{ + + def argTypes(idx: Int): Seq[BaseType] = Seq(TString(), TAny()) + def hintTypes(idx: Int): Seq[BaseType] = Seq() + def varType(idx: Int,argTypes: Seq[BaseType]): BaseType = argTypes(1) + def getFeedbackKey(idx: Int, args: Seq[PrimitiveValue]) = args(0).asString + + def bestGuess(idx: Int,args: Seq[PrimitiveValue],hints: Seq[PrimitiveValue]): PrimitiveValue = + getFeedback(idx, args) match { case Some(s) => s; case None => NullPrimitive() } + def confidence(idx: Int,args: Seq[PrimitiveValue],hints: Seq[PrimitiveValue]): Double = + if(getFeedback(idx, args) == None) { 0.0 } else { 1.0 } + def feedback(idx: Int,args: Seq[PrimitiveValue],v: PrimitiveValue): Unit = + setFeedback(idx, args, v) + def isAcknowledged(idx: Int,args: Seq[PrimitiveValue]): Boolean = + (getFeedback(idx, args) != None) + def reason(idx: Int,args: Seq[PrimitiveValue],hints: Seq[PrimitiveValue]): String = + s"${args(0)} is not a valid ${t}" + def sample(idx: Int,randomness: scala.util.Random,args: Seq[PrimitiveValue],hints: Seq[PrimitiveValue]): PrimitiveValue = + /// TODO: a completely random instance of the selected type would be a better bet. 
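+    // Until then, fall back to the best guess: recorded feedback for this
+    // key when it exists, NULL otherwise (see bestGuess above).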
+ return bestGuess(idx, args, hints) + + +} diff --git a/src/main/scala/mimir/optimizer/Optimizer.scala b/src/main/scala/mimir/optimizer/Optimizer.scala index e1cdaea9..175a8934 100644 --- a/src/main/scala/mimir/optimizer/Optimizer.scala +++ b/src/main/scala/mimir/optimizer/Optimizer.scala @@ -53,7 +53,4 @@ object Optimizer opts.foldLeft(e)( (currE, f) => f(currE) ) } - def gpromOptimize(rawOper: Operator): Operator = { - OperatorTranslation.optimizeWithGProM(rawOper) - } } \ No newline at end of file diff --git a/src/main/scala/mimir/optimizer/operator/PropagateEmptyViews.scala b/src/main/scala/mimir/optimizer/operator/PropagateEmptyViews.scala index c53499ef..2fc70f81 100644 --- a/src/main/scala/mimir/optimizer/operator/PropagateEmptyViews.scala +++ b/src/main/scala/mimir/optimizer/operator/PropagateEmptyViews.scala @@ -31,7 +31,7 @@ class PropagateEmptyViews(typechecker: Typechecker, aggregates: AggregateRegistr aggregates.typecheck( function, args.map { expr => - typechecker.typeOf(expr,schMap) + typechecker.rootType(typechecker.typeOf(expr,schMap)) } ) ), diff --git a/src/main/scala/mimir/parser/ExpressionParser.scala b/src/main/scala/mimir/parser/ExpressionParser.scala index 22e4e9c5..49cbbb12 100644 --- a/src/main/scala/mimir/parser/ExpressionParser.scala +++ b/src/main/scala/mimir/parser/ExpressionParser.scala @@ -142,7 +142,7 @@ object ExpressionParser extends RegexParsers { def exprType: Parser[Type] = ( "int" | "decimal" | "date" | "string" | "rowid" | "type" | "float" | "real" | "varchar" | "any" - ) ^^ { Type.fromString(_) } + ) ^^ { BaseType.fromString(_).get } def typeLeaf: Parser[Expression] = exprType ^^ { (t) => TypePrimitive(t) } diff --git a/src/main/scala/mimir/plot/Heuristics.scala b/src/main/scala/mimir/plot/Heuristics.scala index 1291e7fa..89dbc7d1 100644 --- a/src/main/scala/mimir/plot/Heuristics.scala +++ b/src/main/scala/mimir/plot/Heuristics.scala @@ -15,11 +15,11 @@ object Heuristics ): (Operator, Seq[Plot.Line], Plot.Config) = { //if no lines are specified, try to find the best ones - val columns = db.typechecker.schemaOf(dataQuery) + val columns = db.typechecker.baseSchemaOf(dataQuery) val columnMap = columns.toMap val numericColumns = columns.toSeq - .filter { t => Type.isNumeric(t._2) } + .filter { _._2.isNumeric } .map { _._1 } //if that comes up with nothing either, then throw an exception if(numericColumns.isEmpty){ @@ -39,7 +39,9 @@ object Heuristics // TODO: Plug DetectSeries in here. 
logger.info(s"No explicit columns given, implicitly using X = $x, Y = [${numericColumns.tail.mkString(", ")}]") val commonType = - Typechecker.leastUpperBound(numericColumns.tail.map { y => columnMap(y) }) + Typechecker.leastUpperBound( + numericColumns.tail.map { y => columnMap(y) }:Seq[BaseType] + ) ( dataQuery, numericColumns.tail.map { y => @@ -48,7 +50,6 @@ object Heuristics Map( "XLABEL" -> StringPrimitive(x) ) ++ (commonType match { - case Some(TUser(utype)) => Map("YLABEL" -> StringPrimitive(utype)) case Some(TDate() ) => Map("YLABEL" -> StringPrimitive("Date")) case Some(TTimestamp()) => Map("YLABEL" -> StringPrimitive("Time")) case _ => Map() diff --git a/src/main/scala/mimir/provenance/Provenance.scala b/src/main/scala/mimir/provenance/Provenance.scala index e36f6e70..66428271 100644 --- a/src/main/scala/mimir/provenance/Provenance.scala +++ b/src/main/scala/mimir/provenance/Provenance.scala @@ -16,10 +16,6 @@ object Provenance extends LazyLogging { val mergeRowIdFunction = "MIMIR_MAKE_ROWID" val rowidColnameBase = "MIMIR_ROWID" - - def compileGProM(oper: Operator): (Operator, Seq[String]) = { - OperatorTranslation.compileProvenanceWithGProM(oper) - } def compile(oper: Operator): (Operator, Seq[String]) = { @@ -348,7 +344,7 @@ object Provenance extends LazyLogging { } case Aggregate(gbCols, aggCols, src) => - val sch = db.typechecker.schemaOf(src).toMap + val sch = db.typechecker.baseSchemaOf(src).toMap val castTokenValues = gbCols.map { col => (col.name, Cast(sch(col.name), rowIds(col.name))) }.toMap diff --git a/src/main/scala/mimir/serialization/Json.scala b/src/main/scala/mimir/serialization/Json.scala index c7c5fb97..3e028b0b 100644 --- a/src/main/scala/mimir/serialization/Json.scala +++ b/src/main/scala/mimir/serialization/Json.scala @@ -4,6 +4,7 @@ import play.api.libs.json._ import mimir.Database import mimir.algebra._ +import mimir.algebra.typeregistry.TypeRegistry import mimir.util._ import mimir.views.ViewAnnotation @@ -171,43 +172,43 @@ object Json } } - def toOperator(json: JsValue): Operator = + def toOperator(json: JsValue, types:TypeRegistry): Operator = { val elems = json.asInstanceOf[JsObject].value// .asInstanceOf[JsObject].value elems("type").asInstanceOf[JsString].value match { case "aggregate" => Aggregate( - elems("gb_columns").asInstanceOf[JsArray].value.map { toExpression(_).asInstanceOf[Var] }, + elems("gb_columns").asInstanceOf[JsArray].value.map { toExpression(_, types).asInstanceOf[Var] }, elems("agg_columns").asInstanceOf[JsArray].value.map { fieldJson => val fields = fieldJson.asInstanceOf[JsObject].value AggFunction( fields("function").asInstanceOf[JsString].value, fields("distinct").asInstanceOf[JsBoolean].value, - fields("args").asInstanceOf[JsArray].value.map { toExpression(_) }, + fields("args").asInstanceOf[JsArray].value.map { toExpression(_, types) }, fields("alias").asInstanceOf[JsString].value ) }, - toOperator(elems("source")) + toOperator(elems("source"), types) ) case "annotate" => Annotate( - toOperator(elems("source")), + toOperator(elems("source"), types), elems("annotations").asInstanceOf[JsArray].value.map( annot => { val nameAnnot = annot.asInstanceOf[JsObject].value - ( nameAnnot("name").asInstanceOf[JsString].value, toAnnotation( nameAnnot("annotation").asInstanceOf[JsObject]) ) + ( nameAnnot("name").asInstanceOf[JsString].value, toAnnotation( nameAnnot("annotation").asInstanceOf[JsObject], types) ) }) ) case "join" => Join( - toOperator(elems("left")), - toOperator(elems("right")) + toOperator(elems("left"), types), + 
toOperator(elems("right"), types) ) case "join_left_outer" => LeftOuterJoin( - toOperator(elems("left")), - toOperator(elems("right")), - toExpression(elems("condition")) + toOperator(elems("left"), types), + toOperator(elems("right"), types), + toExpression(elems("condition"), types) ) case "limit" => Limit( @@ -221,7 +222,7 @@ object Json case JsNull => None case _ => throw new RAException("Invalid limit clause in JSON") }, - toOperator(elems("source")) + toOperator(elems("source"), types) ) case "project" => Project( @@ -230,10 +231,10 @@ object Json ProjectArg( fields("name").asInstanceOf[JsString].value, - toExpression(fields("expression")) + toExpression(fields("expression"), types) ) }, - toOperator(elems("source")) + toOperator(elems("source"), types) ) case "sort" => Sort( @@ -241,23 +242,23 @@ object Json val fields = fieldJson.asInstanceOf[JsObject].value SortColumn( - toExpression(fields("expression")), + toExpression(fields("expression"), types), fields("ascending").asInstanceOf[JsBoolean].value ) }, - toOperator(elems("source")) + toOperator(elems("source"), types) ) case "select" => Select( - toExpression(elems("condition")), - toOperator(elems("source")) + toExpression(elems("condition"), types), + toOperator(elems("source"), types) ) case "table_hardcoded" => - val schema = toSchema(elems("schema")) + val schema = toSchema(elems("schema"), types) HardTable( schema, elems("data").as[JsArray].value.map { rowJS => - rowJS.as[JsArray].value.zipWithIndex.map { vJS => toPrimitive(schema(vJS._2)._2, vJS._1) } + rowJS.as[JsArray].value.zipWithIndex.map { vJS => toPrimitive(types.rootType(schema(vJS._2)._2), vJS._1) } } ) @@ -265,21 +266,21 @@ object Json Table( elems("table").asInstanceOf[JsString].value, elems("alias").asInstanceOf[JsString].value, - toSchema(elems("schema")), + toSchema(elems("schema"), types), elems("metadata").asInstanceOf[JsArray].value.map { metaJson => val meta = metaJson.asInstanceOf[JsObject].value ( meta("alias").asInstanceOf[JsString].value, - toExpression(meta("value")), - toType(meta("type")) + toExpression(meta("value"), types), + toType(meta("type"), types) ) } ) case "table_view" => View( elems("name").asInstanceOf[JsString].value, - toOperator(elems("query")), + toOperator(elems("query"), types), elems("annotations").asInstanceOf[JsArray].value.map { annot => ViewAnnotation.withName(annot.asInstanceOf[JsString].value) }.toSet @@ -288,15 +289,15 @@ object Json AdaptiveView( elems("model").asInstanceOf[JsString].value, elems("name").asInstanceOf[JsString].value, - toOperator(elems("query")), + toOperator(elems("query"), types), elems("annotations").asInstanceOf[JsArray].value.map { annot => ViewAnnotation.withName(annot.asInstanceOf[JsString].value) }.toSet ) case "union" => Union( - toOperator(elems("left")), - toOperator(elems("right")) + toOperator(elems("left"), types), + toOperator(elems("right"), types) ) } @@ -309,14 +310,14 @@ object Json "type" -> ofType(annot.typ), "expression" -> ofExpression(annot.expr) )) - def toAnnotation(json: JsValue): AnnotateArg = + def toAnnotation(json: JsValue, types: TypeRegistry): AnnotateArg = { val fields = json.asInstanceOf[JsObject].value AnnotateArg( toAnnotationType(fields("annotation_type")), fields("name").asInstanceOf[JsString].value, - toType(fields("type")), - toExpression(fields("expression")) + toType(fields("type"), types), + toExpression(fields("expression"), types) ) } @@ -403,7 +404,7 @@ object Json } } - def toExpression(json: JsValue): Expression = + def toExpression(json: JsValue, types: 
TypeRegistry): Expression = { val fields = json.asInstanceOf[JsObject].value fields("type").asInstanceOf[JsString].value match { @@ -411,38 +412,38 @@ object Json case "arithmetic" => Arithmetic( Arith.withName(fields("op").asInstanceOf[JsString].value), - toExpression(fields("left")), - toExpression(fields("right")) + toExpression(fields("left"), types), + toExpression(fields("right"), types) ) case "comparison" => Comparison( Cmp.withName(fields("op").asInstanceOf[JsString].value), - toExpression(fields("left")), - toExpression(fields("right")) + toExpression(fields("left"), types), + toExpression(fields("right"), types) ) case "conditional" => Conditional( - toExpression(fields("if")), - toExpression(fields("then")), - toExpression(fields("else")) + toExpression(fields("if"), types), + toExpression(fields("then"), types), + toExpression(fields("else"), types) ) case "function" => Function( fields("name").asInstanceOf[JsString].value, - toExpressionList(fields("args")) + toExpressionList(fields("args"), types) ) case "is_null" => - IsNullExpression(toExpression(fields("arg"))) + IsNullExpression(toExpression(fields("arg"), types)) case "not" => - Not(toExpression(fields("arg"))) + Not(toExpression(fields("arg"), types)) case "jdbc_var" => - JDBCVar(toType(fields("var_type"))) + JDBCVar(toType(fields("var_type"), types)) case "var" => Var(fields("name").asInstanceOf[JsString].value) @@ -454,32 +455,32 @@ object Json VGTerm( fields("model").asInstanceOf[JsString].value, fields("var_index").asInstanceOf[JsNumber].value.toLong.toInt, - toExpressionList(fields("arguments")), - toExpressionList(fields("hints")) + toExpressionList(fields("arguments"), types), + toExpressionList(fields("hints"), types) ) - // fall back to treating it as a primitive type + // fall back to treating it as a primitive constant case t => - toPrimitive(Type.fromString(t), fields("value")) + toPrimitive(types.rootType(types.fromString(t)), fields("value")) } } def ofExpressionList(e: Seq[Expression]): JsArray = JsArray(e.map { ofExpression(_) }) - def toExpressionList(json: JsValue): Seq[Expression] = - json.asInstanceOf[JsArray].value.map { toExpression(_) } + def toExpressionList(json: JsValue, types: TypeRegistry): Seq[Expression] = + json.asInstanceOf[JsArray].value.map { toExpression(_, types) } def ofSchema(schema: Seq[(String,Type)]): JsArray = JsArray(schema.map { case (name, t) => JsObject(Map("name" -> JsString(name), "type" -> ofType(t))) }) - def toSchema(json: JsValue): Seq[(String,Type)] = + def toSchema(json: JsValue, types: TypeRegistry): Seq[(String,Type)] = json.asInstanceOf[JsArray].value.map { elem => val fields = elem.asInstanceOf[JsObject].value ( fields("name").asInstanceOf[JsString].value, - toType(fields("type")) + toType(fields("type"), types) ) } @@ -490,10 +491,19 @@ object Json ViewAnnotation.withName(json.asInstanceOf[JsString].value) def ofType(t: Type): JsValue = - JsString(Type.toString(t)) + JsString(t.toString) - def toType(json: JsValue): Type = - Type.fromString(json.asInstanceOf[JsString].value) + def toType(json: JsValue, types: TypeRegistry): Type = + { + val name = json.asInstanceOf[JsString].value + BaseType.fromString(name) + .getOrElse { + if(types supportsUserType name) { TUser(name) } + else { + throw new IllegalArgumentException("Invalid Type: "+name) + } + } + } def ofPrimitive(p: PrimitiveValue): JsValue = { @@ -511,7 +521,7 @@ object Json } } - def toPrimitive(t: Type, json: JsValue): PrimitiveValue = + def toPrimitive(t: BaseType, json: JsValue): PrimitiveValue = { 
(json,t) match { case (JsNull, _) => NullPrimitive() diff --git a/src/main/scala/mimir/sql/GProMBackend.scala b/src/main/scala/mimir/sql/GProMBackend.scala index fa932c1a..bd1f170f 100644 --- a/src/main/scala/mimir/sql/GProMBackend.scala +++ b/src/main/scala/mimir/sql/GProMBackend.scala @@ -115,8 +115,14 @@ class GProMBackend(backend: String, filename: String, var gpromLogLevel : Int) def dropDB():Unit = sparkBackend.dropDB() def materializeView(name:String): Unit = sparkBackend.materializeView(name) def createTable(tableName: String,oper: mimir.algebra.Operator): Unit = sparkBackend.createTable(tableName, oper) - def readDataSource(name: String,format: String,options: Map[String,String],schema: Option[Seq[(String, mimir.algebra.Type)]],load: Option[String]): Unit = sparkBackend.readDataSource(name, format, options, schema, load) - def getTableSchema(table: String): Option[Seq[(String, Type)]] = sparkBackend.getTableSchema(table) + def readDataSource( + name: String, + format: String, + options: Map[String,String], + schema: Option[Seq[(String, mimir.algebra.BaseType)]], + load: Option[String] + ): Unit = sparkBackend.readDataSource(name, format, options, schema, load) + def getTableSchema(table: String): Option[Seq[(String, BaseType)]] = sparkBackend.getTableSchema(table) def getAllTables(): Seq[String] = sparkBackend.getAllTables() @@ -124,8 +130,8 @@ class GProMBackend(backend: String, filename: String, var gpromLogLevel : Int) def canHandleVGTerms: Boolean = sparkBackend.canHandleVGTerms - def rowIdType: Type = sparkBackend.rowIdType - def dateType: Type = sparkBackend.dateType + def rowIdType: BaseType = sparkBackend.rowIdType + def dateType: BaseType = sparkBackend.dateType def specializeQuery(q: Operator, db: Database): Operator = sparkBackend.specializeQuery(q, db) def listTablesQuery: Operator = sparkBackend.listTablesQuery diff --git a/src/main/scala/mimir/sql/GProMMedadataLookup.scala b/src/main/scala/mimir/sql/GProMMedadataLookup.scala index d68d5022..199d6906 100644 --- a/src/main/scala/mimir/sql/GProMMedadataLookup.scala +++ b/src/main/scala/mimir/sql/GProMMedadataLookup.scala @@ -75,11 +75,11 @@ with LazyLogging case "&" => "DT_INT" case "MIMIR_MAKE_ROWID" => "DT_STRING" case _ => { - val argTypes = args.map(arg => getMimirTypeFromGProMDataTypeString(arg)) + val argTypes = args.map(arg => getMimirTypeFromGProMDataTypeString(arg)).map { db.types.rootType(_) } //logger.debug(s"Metadata lookup: function: $fName(${argSeq.mkString(",")})") getGProMDataTypeStringFromMimirType( fName match { case "sys_op_map_nonnull" => argTypes(0) - case "MIMIR_ENCODED_VGTERM" => db.typechecker.returnTypeOfFunction(VGTermFunctions.bestGuessVGTermFn,argTypes) + case "MIMIR_ENCODED_VGTERM" => db.typechecker.returnTypeOfFunction(VGTermFunctions.bestGuessVGTermFn, argTypes) case "UNCERT" => argTypes(0) case "LEAST" => db.typechecker.returnTypeOfFunction("MIN",argTypes) case _ => { @@ -199,6 +199,7 @@ with LazyLogging case TBool() => BoolPrimitive(false) case TRowId() => RowIdPrimitive("0") case TType() => TypePrimitive(TInt()) + case _ => ??? // This code is not in the active path for now. 
Suppressing warning } }) oName match { diff --git a/src/main/scala/mimir/sql/RABackend.scala b/src/main/scala/mimir/sql/RABackend.scala index 25d5cd64..1264ddd1 100644 --- a/src/main/scala/mimir/sql/RABackend.scala +++ b/src/main/scala/mimir/sql/RABackend.scala @@ -30,10 +30,10 @@ abstract class RABackend(val database:String) { def resultValue(sel:SelectBody):PrimitiveValue = resultRows(sel).head.head*/ - def readDataSource(name:String, format:String, options:Map[String, String], schema:Option[Seq[(String, Type)]], load:Option[String]) : Unit + def readDataSource(name:String, format:String, options:Map[String, String], schema:Option[Seq[(String, BaseType)]], load:Option[String]) : Unit - def getTableSchema(table: String): Option[Seq[(String, Type)]] + def getTableSchema(table: String): Option[Seq[(String, BaseType)]] def getAllTables(): Seq[String] @@ -43,8 +43,8 @@ abstract class RABackend(val database:String) { def close() def canHandleVGTerms: Boolean - def rowIdType: Type - def dateType: Type + def rowIdType: BaseType + def dateType: BaseType def specializeQuery(q: Operator, db: Database): Operator def listTablesQuery: Operator diff --git a/src/main/scala/mimir/sql/SparkBackend.scala b/src/main/scala/mimir/sql/SparkBackend.scala index 3e1e8491..2deec959 100644 --- a/src/main/scala/mimir/sql/SparkBackend.scala +++ b/src/main/scala/mimir/sql/SparkBackend.scala @@ -53,6 +53,8 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex with LazyLogging { + var sparkTranslator: OperatorTranslation = null + var sparkSql : SQLContext = null //ExperimentalOptions.enable("remoteSpark") val envHasS3Keys = (Option(System.getenv("AWS_ACCESS_KEY_ID")), Option(System.getenv("AWS_SECRET_ACCESS_KEY"))) match { @@ -227,7 +229,7 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex val fClassName = sparkSql.sparkSession.sessionState.catalog.lookupFunctionInfo(fidentifier).getClassName if(fClassName.startsWith("org.apache.spark.sql.catalyst.expressions.aggregate")){ Some((fidentifier.funcName.toUpperCase(), - (inputTypes:Seq[Type]) => { + (inputTypes:Seq[BaseType]) => { val inputs = inputTypes.map(inp => Literal(OperatorTranslation.getNative(NullPrimitive(), inp)).asInstanceOf[org.apache.spark.sql.catalyst.expressions.Expression]) val constructorTypes = inputs.map(inp => classOf[org.apache.spark.sql.catalyst.expressions.Expression]) val dt = OperatorTranslation.getMimirType( Class.forName(fClassName).getDeclaredConstructor(constructorTypes:_*).newInstance(inputs:_*) @@ -261,7 +263,7 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex logger.trace(s"$compiledOp") logger.trace("------------------------------------------------------------") if(sparkSql == null) throw new Exception("There is no spark context") - sparkOper = OperatorTranslation.mimirOpToSparkOp(compiledOp) + sparkOper = sparkTranslator.mimirOpToSparkOp(compiledOp) logger.trace("------------------------ spark op --------------------------") logger.trace(s"$sparkOper") logger.trace("------------------------------------------------------------") @@ -290,7 +292,7 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex } - def readDataSource(name:String, format:String, options:Map[String, String], schema:Option[Seq[(String, Type)]], load:Option[String]) = { + def readDataSource(name:String, format:String, options:Map[String, String], schema:Option[Seq[(String, BaseType)]], load:Option[String]) = { if(sparkSql == null) throw new 
Exception("There is no spark context") def copyToS3(file:String): String = { val accessKeyId = System.getenv("AWS_ACCESS_KEY_ID") @@ -415,7 +417,7 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex } - def getTableSchema(table: String): Option[Seq[(String, Type)]] = { + def getTableSchema(table: String): Option[Seq[(String, BaseType)]] = { if(sparkSql == null) throw new Exception("There is no spark context") if(sparkSql.sparkSession.catalog.tableExists(table)) Some(sparkSql.sparkSession.catalog.listColumns(table).collect.map(col => (col.name, OperatorTranslation.getMimirType( OperatorTranslation.dataTypeFromHiveDataTypeString(col.dataType))))) @@ -441,8 +443,8 @@ class SparkBackend(override val database:String, maintenance:Boolean = false) ex } def canHandleVGTerms: Boolean = true - def rowIdType: Type = TString() - def dateType: Type = TDate() + def rowIdType: BaseType = TString() + def dateType: BaseType = TDate() def specializeQuery(q: Operator, db: mimir.Database): Operator = { q } diff --git a/src/main/scala/mimir/sql/SqlToRA.scala b/src/main/scala/mimir/sql/SqlToRA.scala index ddeba5b6..312c8d33 100644 --- a/src/main/scala/mimir/sql/SqlToRA.scala +++ b/src/main/scala/mimir/sql/SqlToRA.scala @@ -603,11 +603,13 @@ class SqlToRA(db: Database) throw new SQLException(s"Invalid CAST: $cast") } val target = convert(params(0)) - val t = params(1) match { - case s: StringValue => Type.fromString(s.toRawString) - case c: Column => Type.fromString(c.getColumnName) - case _ => throw new SQLException(s"Invalid CAST Type: $cast") - } + val t = (params(1) match { + case s: StringValue => BaseType.fromString(s.toRawString) + case c: Column => BaseType.fromString(c.getColumnName) + case _ => None + }).getOrElse { + throw new SQLException(s"Invalid CAST Type: $cast") + } return mimir.algebra.Function("CAST", Seq(target, TypePrimitive(t))) } diff --git a/src/main/scala/mimir/sql/sqlite/MimirFunction.scala b/src/main/scala/mimir/sql/sqlite/MimirFunction.scala index 6f193114..bbb3b307 100644 --- a/src/main/scala/mimir/sql/sqlite/MimirFunction.scala +++ b/src/main/scala/mimir/sql/sqlite/MimirFunction.scala @@ -3,7 +3,6 @@ package mimir.sql.sqlite; import java.sql.SQLException import mimir.algebra._ -import mimir.algebra.Type._ import mimir.util._ abstract class MimirFunction extends org.sqlite.Function @@ -11,7 +10,7 @@ abstract class MimirFunction extends org.sqlite.Function def value_mimir(idx: Int): PrimitiveValue = value_mimir(idx, TAny()) - def value_mimir(idx: Int, t:Type): PrimitiveValue = + def value_mimir(idx: Int, t:BaseType): PrimitiveValue = { if(value_type(idx) == SQLiteCompat.NULL){ NullPrimitive() } else { t match { @@ -40,13 +39,13 @@ abstract class MimirFunction extends org.sqlite.Function case RowIdPrimitive(r) => result(r) case t:TimestampPrimitive => result(t.asString) case i:IntervalPrimitive => result(i.asString) - case TypePrimitive(t) => result(Type.toString(t)) + case TypePrimitive(t) => result(t.toString) case NullPrimitive() => result() } } } -abstract class SimpleMimirFunction(argTypes: List[Type]) extends MimirFunction +abstract class SimpleMimirFunction(argTypes: List[BaseType]) extends MimirFunction { def apply(args: List[PrimitiveValue]): PrimitiveValue diff --git a/src/main/scala/mimir/sql/sqlite/SQLiteCompat.scala b/src/main/scala/mimir/sql/sqlite/SQLiteCompat.scala index 1ea05de2..97939984 100644 --- a/src/main/scala/mimir/sql/sqlite/SQLiteCompat.scala +++ b/src/main/scala/mimir/sql/sqlite/SQLiteCompat.scala @@ -61,8 +61,8 @@ object 
SQLiteCompat extends LazyLogging{ val name = x(1).asString.toUpperCase.trim val rawType = x(2).asString.trim val baseType = rawType.split("\\(")(0).trim - val inferredType = try { - Type.fromString(baseType) + val inferredType:BaseType = try { + BaseType.fromString(baseType).getOrElse { throw new RAException(s"Unsupported type '$baseType'") } } catch { case e:RAException => logger.warn(s"While getting schema for table '$table': ${e.getMessage}") @@ -300,7 +300,7 @@ object MimirCast extends org.sqlite.Function with LazyLogging { def xFunc(): Unit = { if (args != 2) { throw new java.sql.SQLDataException("NOT THE RIGHT NUMBER OF ARGS FOR MIMIRCAST, EXPECTED 2 IN FORM OF MIMIRCAST(COLUMN,TYPE)") } try { - val t = Type.toSQLiteType(value_int(1)) + val t = BaseType.idTypeOrder(value_int(1)) val v = value_text(0) logger.trace(s"Casting $v as $t") t match { @@ -323,45 +323,6 @@ object MimirCast extends org.sqlite.Function with LazyLogging { case TString() | TRowId() | TDate() | TTimestamp() => result(value_text(0)) - case TUser(name) => - val v:String = value_text(0) - if(v != null) { - Type.rootType(t) match { - case TRowId() => - result(value_text(0)) - case TString() | TDate() | TTimestamp() | TInterval() => - val txt = value_text(0) - if(TypeRegistry.matches(name, txt)){ - result(value_text(0)) - } else { - result() - } - case TInt() | TBool() => - if(TypeRegistry.matches(name, value_text(0))){ - result(value_int(0)) - } else { - result() - } - case TFloat() => - if(TypeRegistry.matches(name, value_text(0))){ - result(value_double(0)) - } else { - result() - } - case TAny() => - if(TypeRegistry.matches(name, value_text(0))){ - result(value_text(0)) - } else { - result() - } - case TUser(_) | TType() => - throw new Exception("In SQLiteCompat expected natural type but got: " + Type.rootType(t).toString()) - } - } - else{ - result() - } - case _ => result("I assume that you put something other than a number in, this functions works like, MIMIRCAST(column,type), the types are int values, 1 is int, 2 is double, 3 is string, and 5 is null, so MIMIRCAST(COL,1) is casting column 1 to int") // throw new java.sql.SQLDataException("Well here we are, I'm not sure really what went wrong but it happened in MIMIRCAST, maybe it was a type, good luck") diff --git a/src/main/scala/mimir/sql/sqlite/SpecializeForSQLite.scala b/src/main/scala/mimir/sql/sqlite/SpecializeForSQLite.scala index 834d77cf..95fea1cc 100644 --- a/src/main/scala/mimir/sql/sqlite/SpecializeForSQLite.scala +++ b/src/main/scala/mimir/sql/sqlite/SpecializeForSQLite.scala @@ -9,13 +9,13 @@ import mimir.util._ object SpecializeForSQLite { - def apply(e: Expression): Expression = + def apply(e: Expression, db: Database): Expression = { (e match { case Function("CAST", Seq(target, TypePrimitive(t))) => {//println("TYPE ID: "+t.id(t)) - Function("MIMIRCAST", Seq(target, IntPrimitive(Type.id(t))))} + Function("MIMIRCAST", Seq(target, IntPrimitive(db.types.idForType(db.types.rootType(t)).toLong)))} case Function("CAST", _) => throw new SQLException("Invalid CAST: "+e) @@ -49,7 +49,7 @@ object SpecializeForSQLite { case _ => e - }).recur( apply(_:Expression) ) + }).recur( apply(_:Expression, db) ) } def apply(agg: AggFunction, typeOf: Expression => Type): AggFunction = @@ -68,7 +68,7 @@ object SpecializeForSQLite { def apply(o: Operator, db: Database): Operator = { o.recurExpressions( - apply(_:Expression) + apply(_:Expression, db) ) match { case Aggregate(gb, agg, source) => { diff --git a/src/main/scala/mimir/util/JDBCUtils.scala b/src/main/scala/mimir/util/JDBCUtils.scala index 2f9703ee..58a9970d 100644 ---
a/src/main/scala/mimir/util/JDBCUtils.scala +++ b/src/main/scala/mimir/util/JDBCUtils.scala @@ -9,7 +9,7 @@ import java.text.SimpleDateFormat object JDBCUtils { - def convertSqlType(t: Int): Type = { + def convertSqlType(t: Int): BaseType = { t match { case (java.sql.Types.FLOAT | java.sql.Types.DECIMAL | @@ -26,7 +26,7 @@ object JDBCUtils { } } - def convertMimirType(t: Type): Int = { + def convertMimirType(t: BaseType): Int = { t match { case TInt() => java.sql.Types.INTEGER case TFloat() => java.sql.Types.DOUBLE @@ -38,12 +38,11 @@ object JDBCUtils { case TBool() => java.sql.Types.INTEGER case TType() => java.sql.Types.VARCHAR case TInterval() => java.sql.Types.VARCHAR - case TUser(t) => convertMimirType(TypeRegistry.baseType(t)) } } - def convertFunction(t: Type, field: Integer, dateType: Type = TDate()): (ResultSet => PrimitiveValue) = + def convertFunction(t: BaseType, field: Integer, dateType: Type = TDate()): (ResultSet => PrimitiveValue) = { val checkNull: ((ResultSet, => PrimitiveValue) => PrimitiveValue) = { (r, call) => { @@ -60,7 +59,8 @@ object JDBCUtils { case TString() => (r) => checkNull(r, { StringPrimitive(r.getString(field)) }) case TRowId() => (r) => checkNull(r, { RowIdPrimitive(r.getString(field)) }) case TBool() => (r) => checkNull(r, { BoolPrimitive(r.getInt(field) != 0) }) - case TType() => (r) => checkNull(r, { TypePrimitive(Type.fromString(r.getString(field))) }) + case TType() => (r) => checkNull(r, { val name = r.getString(field) + TypePrimitive(BaseType.fromString(name).getOrElse { TUser(name) } ) }) case TDate() => dateType match { case TDate() => (r) => { val d = r.getDate(field); if(d == null){ NullPrimitive() } else { convertDate(d) } } @@ -87,11 +87,10 @@ object JDBCUtils { } case TInterval() => (r) => { TextUtils.parseInterval(r.getString(field)) } - case TUser(t) => convertFunction(TypeRegistry.baseType(t), field, dateType) } } - def convertField(t: Type, results: ResultSet, field: Integer, rowIdType: Type = TString()): PrimitiveValue = + def convertField(t: BaseType, results: ResultSet, field: Integer, rowIdType: Type = TString()): PrimitiveValue = { convertFunction( t match { @@ -144,13 +143,13 @@ object JDBCUtils { extractAllRows(results, schema) } - def extractAllRows(results: ResultSet, schema: Seq[Type]): JDBCResultSetIterable = + def extractAllRows(results: ResultSet, schema: Seq[BaseType]): JDBCResultSetIterable = { new JDBCResultSetIterable(results, schema) } } -class JDBCResultSetIterable(results: ResultSet, schema: Seq[Type]) +class JDBCResultSetIterable(results: ResultSet, schema: Seq[BaseType]) extends Iterator[Seq[PrimitiveValue]] { def next(): List[PrimitiveValue] = diff --git a/src/main/scala/mimir/util/JSONBuilder.scala b/src/main/scala/mimir/util/JSONBuilder.scala index 6dbf194c..4f5db199 100644 --- a/src/main/scala/mimir/util/JSONBuilder.scala +++ b/src/main/scala/mimir/util/JSONBuilder.scala @@ -81,7 +81,7 @@ object JSONBuilder { case d:Double => JsNumber(d) case b:Boolean => JsBoolean(b) case seq:Seq[Any] => listJs(seq) - case map:Map[String,Any] => dictJs(map) + case map:Map[_,_] => dictJs(map.asInstanceOf[Map[String,Any]]) case jsval:JsValue => jsval case prim:PrimitiveValue => primJs(prim) case _ => JsString(content.toString()) diff --git a/src/main/scala/mimir/util/LoadCSV.scala b/src/main/scala/mimir/util/LoadCSV.scala index 489ec68b..ca8f0163 100644 --- a/src/main/scala/mimir/util/LoadCSV.scala +++ b/src/main/scala/mimir/util/LoadCSV.scala @@ -32,8 +32,6 @@ object LoadCSV extends StrictLogging { def handleLoadTableRaw(db: 
Database, targetTable: String, sourceFile: File, options: Map[String,String] = Map()) = LoadData.handleLoadTableRaw(db, targetTable, sourceFile, options) - def handleLoadTableRaw(db: Database, targetTable: String, targetSchema:Option[Seq[(String,Type)]], sourceFile: File, options: Map[String,String]) = + def handleLoadTableRaw(db: Database, targetTable: String, targetSchema:Option[Seq[(String,BaseType)]], sourceFile: File, options: Map[String,String]) = LoadData.handleLoadTableRaw(db, targetTable, targetSchema, sourceFile, options, "csv") } - - diff --git a/src/main/scala/mimir/util/LoadData.scala b/src/main/scala/mimir/util/LoadData.scala index 4c804191..9e7ad454 100644 --- a/src/main/scala/mimir/util/LoadData.scala +++ b/src/main/scala/mimir/util/LoadData.scala @@ -21,12 +21,12 @@ object LoadData extends StrictLogging { def handleLoadTableRaw(db: Database, targetTable: String, sourceFile: File, options: Map[String,String] = Map(), format:String = "csv"){ //we need to handle making data publicly accessible here and adding it to spark for remote connections val path = if(sourceFile.getPath.contains(":/")) sourceFile.getPath.replaceFirst(":/", "://") else sourceFile.getAbsolutePath - db.backend.readDataSource(targetTable, format, options, db.tableSchema(targetTable), Some(path)) + db.backend.readDataSource(targetTable, format, options, db.tableBaseSchema(targetTable), Some(path)) } - def handleLoadTableRaw(db: Database, targetTable: String, targetSchema:Option[Seq[(String,Type)]], sourceFile: File, options: Map[String,String], format:String){ + def handleLoadTableRaw(db: Database, targetTable: String, targetSchema:Option[Seq[(String,BaseType)]], sourceFile: File, options: Map[String,String], format:String){ val schema = targetSchema match { - case None => db.tableSchema(targetTable) + case None => db.tableBaseSchema(targetTable) case _ => targetSchema } //we need to handle making data publicly accessible here and adding it to spark for remote connections diff --git a/src/main/scala/mimir/util/SealedSubclassEnumeration.scala b/src/main/scala/mimir/util/SealedSubclassEnumeration.scala new file mode 100644 index 00000000..26f3491d --- /dev/null +++ b/src/main/scala/mimir/util/SealedSubclassEnumeration.scala @@ -0,0 +1,45 @@ +package mimir.util + +import language.experimental.macros +import scala.reflect.macros.Context + +// Adapted from +// https://stackoverflow.com/questions/13671734/iteration-over-a-sealed-trait-in-scala +object SealedSubclassEnumeration { + def values[A]: Set[A] = macro values_impl[A] + + def values_impl[A: c.WeakTypeTag](c: Context) = { + import c.universe._ + + val symbol = weakTypeOf[A].typeSymbol + + if (!symbol.isClass) c.abort( + c.enclosingPosition, + "Can only enumerate values of a sealed trait or class." + ) else if (!symbol.asClass.isSealed) c.abort( + c.enclosingPosition, + "Can only enumerate values of a sealed trait or class." + ) else { + val children = symbol.asClass.knownDirectSubclasses.toList + + if (!children.forall(_.isModuleClass)) c.abort( + c.enclosingPosition, + "All children must be objects." 
+ ) else c.Expr[Set[A]] { + def sourceModuleRef(sym: Symbol) = Ident( + sym.asInstanceOf[ + scala.reflect.internal.Symbols#Symbol + ].sourceModule.asInstanceOf[Symbol] + ) + + Apply( + Select( + reify(Set).tree, + newTermName("apply") + ), + children.map(sourceModuleRef(_)) + ) + } + } + } +} \ No newline at end of file diff --git a/src/main/scala/mimir/util/SparkUtils.scala b/src/main/scala/mimir/util/SparkUtils.scala index 725dd6bc..a9471f1f 100644 --- a/src/main/scala/mimir/util/SparkUtils.scala +++ b/src/main/scala/mimir/util/SparkUtils.scala @@ -16,7 +16,7 @@ import java.io.File object SparkUtils { //TODO:there are a bunch of hacks in this conversion function because type conversion in operator translator // needs to be done correctly - def convertFunction(t: Type, field: Integer, dateType: Type = TDate()): (Row => PrimitiveValue) = + def convertFunction(t: BaseType, field: Integer, dateType: Type = TDate()): (Row => PrimitiveValue) = { val checkNull: ((Row, => PrimitiveValue) => PrimitiveValue) = { (r, call) => { @@ -70,7 +70,8 @@ object SparkUtils { } } } }) - case TType() => (r) => checkNull(r, { TypePrimitive(Type.fromString(r.getString(field))) }) + case TType() => (r) => checkNull(r, { val name = r.getString(field) + TypePrimitive(BaseType.fromString(name).getOrElse { TUser(name) }) }) case TDate() => dateType match { case TDate() => (r) => { val d = r.getDate(field); if(d == null){ NullPrimitive() } else { convertDate(d) } } @@ -97,11 +98,10 @@ object SparkUtils { } case TInterval() => (r) => { TextUtils.parseInterval(r.getString(field)) } - case TUser(t) => convertFunction(TypeRegistry.baseType(t), field, dateType) } } - def convertField(t: Type, results: Row, field: Integer, dateType: Type = TString()): PrimitiveValue = + def convertField(t: BaseType, results: Row, field: Integer, dateType: Type = TString()): PrimitiveValue = { convertFunction( t match { @@ -149,7 +149,7 @@ object SparkUtils { extractAllRows(results, OperatorTranslation.structTypeToMimirSchema(results.schema).map(_._2)) - def extractAllRows(results: DataFrame, schema: Seq[Type]): SparkDataFrameIterable = + def extractAllRows(results: DataFrame, schema: Seq[BaseType]): SparkDataFrameIterable = { new SparkDataFrameIterable(results.collect().iterator, schema) } @@ -161,12 +161,12 @@ object SparkUtils { val models = ClassFinder.concreteSubclasses("mimir.models.Model", classMap).map(clazz => Class.forName(clazz.name)).toSeq val operators = ClassFinder.concreteSubclasses("mimir.algebra.Operator", classMap).map(clazz => Class.forName(clazz.name)).toSeq val expressions = ClassFinder.concreteSubclasses("mimir.algebra.Expression", classMap).map(clazz => Class.forName(clazz.name)).toSeq - val miscClasses = Seq(Class.forName("org.opengis.referencing.datum.Ellipsoid"),Class.forName("org.geotools.referencing.datum.DefaultEllipsoid")) + val miscClasses = Seq[Class[_]](Class.forName("org.opengis.referencing.datum.Ellipsoid"),Class.forName("org.geotools.referencing.datum.DefaultEllipsoid")) (models ++ operators ++ expressions ++ miscClasses).toArray } } -class SparkDataFrameIterable(results: Iterator[Row], schema: Seq[Type]) +class SparkDataFrameIterable(results: Iterator[Row], schema: Seq[BaseType]) extends Iterator[Seq[PrimitiveValue]] { def next(): List[PrimitiveValue] = diff --git a/src/main/scala/mimir/util/SqlUtils.scala b/src/main/scala/mimir/util/SqlUtils.scala index 358b81e3..a4899ed6 100644 --- a/src/main/scala/mimir/util/SqlUtils.scala +++ b/src/main/scala/mimir/util/SqlUtils.scala @@ -200,6 +200,7 @@ object 
SqlUtils { case Some(tblSchmd) => tblSchmd case None => db.backend.getTableSchema(table.getName()) match { case Some(tblSch) => tblSch + case None => throw new SQLException("Unknown Table: "+table.getName()) } }).map(_._1).toList++List("ROWID") ) diff --git a/src/main/scala/mimir/util/TextUtils.scala b/src/main/scala/mimir/util/TextUtils.scala index 2cd370f2..c590d089 100644 --- a/src/main/scala/mimir/util/TextUtils.scala +++ b/src/main/scala/mimir/util/TextUtils.scala @@ -5,7 +5,7 @@ import mimir.algebra._ object TextUtils extends LazyLogging { - def parsePrimitive(t: Type, s: String): PrimitiveValue = + def parsePrimitive(t: BaseType, s: String): PrimitiveValue = { t match { case TInt() => IntPrimitive(java.lang.Long.parseLong(s)) @@ -20,9 +20,8 @@ object TextUtils extends LazyLogging { case "NO" | "FALSE" | "0" => BoolPrimitive(false) } case TRowId() => RowIdPrimitive(s) - case TType() => TypePrimitive(Type.fromString(s)) + case TType() => TypePrimitive(BaseType.fromString(s).getOrElse { TUser(s) }) case TAny() => throw new RAException("Can't cast string to TAny") - case TUser(t) => parsePrimitive(TypeRegistry.baseType(t), s) } } diff --git a/src/main/scala/mimir/views/ViewManager.scala b/src/main/scala/mimir/views/ViewManager.scala index 654bb554..50a9ce1c 100644 --- a/src/main/scala/mimir/views/ViewManager.scala +++ b/src/main/scala/mimir/views/ViewManager.scala @@ -103,7 +103,7 @@ class ViewManager(db:Database) extends LazyLogging { results.take(1).headOption.map(_.toSeq).map( { case Seq(StringPrimitive(s), IntPrimitive(meta)) => { - val query = Json.toOperator(Json.parse(s)) + val query = Json.toOperator(Json.parse(s), db.types) val isMaterialized = meta != 0 diff --git a/src/test/scala/mimir/algebra/SerializationSpec.scala b/src/test/scala/mimir/algebra/SerializationSpec.scala index 4b05280d..b340ac57 100644 --- a/src/test/scala/mimir/algebra/SerializationSpec.scala +++ b/src/test/scala/mimir/algebra/SerializationSpec.scala @@ -50,7 +50,7 @@ object SerializationSpec extends SQLTestSpecification("SerializationTest") { i = i + 1; val query = db.sql.convert(s) val serialized = Json.ofOperator(query) - val deserialized = Json.toOperator(serialized) + val deserialized = Json.toOperator(serialized, db.types) Some(deserialized must be equalTo query) } diff --git a/src/test/scala/mimir/algebra/gprom/OperatorTranslationSpec.scala b/src/test/scala/mimir/algebra/gprom/OperatorTranslationSpec.scala index d385d2e1..9309d4d7 100644 --- a/src/test/scala/mimir/algebra/gprom/OperatorTranslationSpec.scala +++ b/src/test/scala/mimir/algebra/gprom/OperatorTranslationSpec.scala @@ -15,7 +15,7 @@ import java.io.File import mimir.util.LoadCSV import mimir.provenance.Provenance object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorTranslation") with BeforeAll with AfterAll { args(skipAll = false) @@ -61,10 +61,10 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT sequential "Compile Provenance for Projections" >> { /*LoggerUtils.debug( "mimir.algebra.gprom.OperatorTranslation" ){*/ val table = db.table("CQ") - val (oper, provCols) = OperatorTranslation.compileProvenanceWithGProM(table) + val (oper, provCols) = db.gpromTranslator.compileProvenanceWithGProM(table) provOp = oper; query(table){ _.toSeq.map { _.provenance.asString } must contain( @@ -79,7 +79,7 @@ object
OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT } "Compile Determinism for Projections" >> { - val (oper, colDet, rowDet) = OperatorTranslation.compileTaintWithGProM(provOp) + val (oper, colDet, rowDet) = db.gpromTranslator.compileTaintWithGProM(provOp) val table = db.table("CQ") query(table){ _.toList.map( row => { @@ -104,11 +104,11 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT "Compile Provenance for Aggregates" >> { //LoggerUtils.debug( // "mimir.algebra.gprom.OperatorTranslation" //){ val statements = db.parse("select COUNT(COMMENT_ARG_0) from CQ") val testOper = db.sql.convert(statements.head.asInstanceOf[Select]) - val (oper, provCols) = OperatorTranslation.compileProvenanceWithGProM(testOper) + val (oper, provCols) = db.gpromTranslator.compileProvenanceWithGProM(testOper) provOp = oper; provCols must contain("PROV_Q_MIMIR__ROWID") @@ -125,7 +125,7 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT } "Compile Determinism for Aggregates" >> { - val (oper, colDet, rowDet) = OperatorTranslation.compileTaintWithGProM(provOp) + val (oper, colDet, rowDet) = db.gpromTranslator.compileTaintWithGProM(provOp) val statements = db.parse("select COUNT(COMMENT_ARG_0) from CQ") val testOper = db.sql.convert(statements.head.asInstanceOf[Select]) query(testOper){ @@ -224,7 +224,7 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT val queryStr = descAndQuery._1._2 val testOper = db.select(queryStr) //gp.metadataLookupPlugin.setOper(testOper) - val gpromNode = OperatorTranslation.mimirOperatorToGProMList(testOper) + val gpromNode = db.gpromTranslator.mimirOperatorToGProMList(testOper) gpromNode.write() //val memctx = GProMWrapper.inst.gpromCreateMemContext() val nodeStr = GProMWrapper.inst.gpromNodeToString(gpromNode.getPointer()) @@ -250,7 +250,7 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT var operStr2 = testOper2.toString() //val memctx = GProMWrapper.inst.gpromCreateMemContext() val gpromNode = GProMWrapper.inst.rewriteQueryToOperatorModel(queryStr+";") - val testOper = OperatorTranslation.gpromStructureToMimirOperator(0, gpromNode, null) + val testOper = db.gpromTranslator.gpromStructureToMimirOperator(0, gpromNode, null) var operStr = testOper.toString() //GProMWrapper.inst.gpromFreeMemContext(memctx) val ret = operStr must be equalTo operStr2 or @@ -261,7 +261,7 @@ } or { val optOper = GProMWrapper.inst.optimizeOperatorModel(gpromNode.getPointer) - val resOp = OperatorTranslation.gpromStructureToMimirOperator(0, optOper, null) + val resOp = db.gpromTranslator.gpromStructureToMimirOperator(0, optOper, null) getQueryResultsBackend(resOp.removeColumn(Provenance.rowidColnameBase)) must be equalTo getQueryResultsBackend(queryStr) } ret @@ -273,10 +273,10 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT val queryStr = descAndQuery._1._2 val testOper = db.select(queryStr) var operStr = testOper.toString() - val gpromNode = OperatorTranslation.mimirOperatorToGProMList(testOper) + val gpromNode = db.gpromTranslator.mimirOperatorToGProMList(testOper) gpromNode.write() //val memctx = GProMWrapper.inst.gpromCreateMemContext() - val testOper2 = OperatorTranslation.gpromStructureToMimirOperator(0, gpromNode, null) + val testOper2 =
db.gpromTranslator.gpromStructureToMimirOperator(0, gpromNode, null) var operStr2 = testOper2.toString() //GProMWrapper.inst.gpromFreeMemContext(memctx) val ret = operStr must be equalTo operStr2 or @@ -287,7 +287,7 @@ } or { val optOper = GProMWrapper.inst.optimizeOperatorModel(gpromNode.getPointer) - val resOp = OperatorTranslation.gpromStructureToMimirOperator(0, optOper, null) + val resOp = db.gpromTranslator.gpromStructureToMimirOperator(0, optOper, null) getQueryResultsBackend(resOp.removeColumn(Provenance.rowidColnameBase)) must be equalTo getQueryResultsBackend(queryStr) } ret @@ -299,10 +299,10 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT val queryStr = descAndQuery._1._2 //val memctx = GProMWrapper.inst.gpromCreateMemContext() val gpromNode = GProMWrapper.inst.rewriteQueryToOperatorModel(queryStr+";") - val testOper = OperatorTranslation.gpromStructureToMimirOperator(0, gpromNode, null) + val testOper = db.gpromTranslator.gpromStructureToMimirOperator(0, gpromNode, null) val nodeStr = GProMWrapper.inst.gpromNodeToString(gpromNode.getPointer()) //GProMWrapper.inst.gpromFreeMemContext(memctx) - val gpromNode2 = OperatorTranslation.mimirOperatorToGProMList(testOper) + val gpromNode2 = db.gpromTranslator.mimirOperatorToGProMList(testOper) gpromNode2.write() //GProMWrapper.inst.gpromCreateMemContext() val nodeStr2 = GProMWrapper.inst.gpromNodeToString(gpromNode2.getPointer()) @@ -325,12 +325,12 @@ object OperatorTranslationSpec extends GProMSQLTestSpecification("GProMOperatorT val queryStr = descAndQuery._1._2 val testOper = db.select(queryStr) - val timeForRewriteThroughOperatorTranslation = time { - val gpromNode = OperatorTranslation.mimirOperatorToGProMList(testOper) + val timeForRewriteThroughOperatorTranslator = time { + val gpromNode = db.gpromTranslator.mimirOperatorToGProMList(testOper) gpromNode.write() //val smemctx = GProMWrapper.inst.gpromCreateMemContext() val gpromNode2 = GProMWrapper.inst.provRewriteOperator(gpromNode.getPointer()) - val testOper2 = OperatorTranslation.gpromStructureToMimirOperator(0, gpromNode2, null) + val testOper2 = db.gpromTranslator.gpromStructureToMimirOperator(0, gpromNode2, null) val operStr = ""//testOper2.toString() //GProMWrapper.inst.gpromFreeMemContext(smemctx) operStr @@ -344,9 +344,9 @@ testOper2.toString()*/"" } - //timeForRewriteThroughOperatorTranslation._1 must be equalTo timeForRewriteThroughSQL._1 - //println(s"via SQL: ${timeForRewriteThroughSQL._2} via RA: ${timeForRewriteThroughOperatorTranslation._2}") - val ret = (timeForRewriteThroughOperatorTranslation._2 should be lessThan timeForRewriteThroughSQL._2) or (timeForRewriteThroughOperatorTranslation._2 should be lessThan (timeForRewriteThroughSQL._2*10)) + //timeForRewriteThroughOperatorTranslator._1 must be equalTo timeForRewriteThroughSQL._1 + //println(s"via SQL: ${timeForRewriteThroughSQL._2} via RA: ${timeForRewriteThroughOperatorTranslator._2}") + val ret = (timeForRewriteThroughOperatorTranslator._2 should be lessThan timeForRewriteThroughSQL._2) or (timeForRewriteThroughOperatorTranslator._2 should be lessThan (timeForRewriteThroughSQL._2*10)) ret } } diff --git a/src/test/scala/mimir/demo/MimirGProMDemo.scala b/src/test/scala/mimir/demo/MimirGProMDemo.scala index 1f9ebcfd..45278dd8 100644 --- a/src/test/scala/mimir/demo/MimirGProMDemo.scala +++
b/src/test/scala/mimir/demo/MimirGProMDemo.scala @@ -56,7 +56,7 @@ object MimirGProMDemo extends GProMSQLTestSpecification("MimirGProMDemo") // "mimir.algebra.gprom.OperatorTranslation" //){ val table = db.table("MVQ") - val (oper, provCols) = OperatorTranslation.compileProvenanceWithGProM(table) + val (oper, provCols) = db.gpromTranslator.compileProvenanceWithGProM(table) provOp = oper; query(table){ _.toSeq.map { _.provenance.asString } must contain( @@ -71,7 +71,7 @@ object MimirGProMDemo extends GProMSQLTestSpecification("MimirGProMDemo") } "Compile Determinism for Mimir Operators" >> { - val (oper, colDet, rowDet) = OperatorTranslation.compileTaintWithGProM(provOp) + val (oper, colDet, rowDet) = db.gpromTranslator.compileTaintWithGProM(provOp) val table = db.table("MVQ") query(table){ _.toList.map( row => { @@ -105,7 +105,7 @@ object MimirGProMDemo extends GProMSQLTestSpecification("MimirGProMDemo") println("------------mimir Op Json-----------------") println(mimir.serialization.Json.ofOperator(oper).toString) println("------------------------------------------")*/ - val gpromNode = TranslationUtils.scalaListToGProMList(Seq(OperatorTranslation.mimirOperatorToGProMOperator(oper))) + val gpromNode = TranslationUtils.scalaListToGProMList(Seq(db.gpromTranslator.mimirOperatorToGProMOperator(oper))) gpromNode.write() val gpNodeStr = GProMWrapper.inst.gpromNodeToString(gpromNode.getPointer()) .replaceAll("ADDRESS: '[a-z0-9]+'", "ADDRESS: ''") @@ -121,7 +121,7 @@ object MimirGProMDemo extends GProMSQLTestSpecification("MimirGProMDemo") def transGOpPrint(oper:Operator) : String = { org.gprom.jdbc.jna.GProM_JNA.GC_LOCK.synchronized{ val memctx = GProMWrapper.inst.gpromCreateMemContext() - val gpromNode = TranslationUtils.scalaListToGProMList(Seq(OperatorTranslation.mimirOperatorToGProMOperator(oper))) + val gpromNode = TranslationUtils.scalaListToGProMList(Seq(db.gpromTranslator.mimirOperatorToGProMOperator(oper))) gpromNode.write() val gpNodeStr = GProMWrapper.inst.gpromNodeToString(gpromNode.getPointer()) .replaceAll("ADDRESS: '[a-z0-9]+'", "ADDRESS: ''") @@ -129,7 +129,7 @@ object MimirGProMDemo extends GProMSQLTestSpecification("MimirGProMDemo") /*println("---------------gprom Op-------------------") println(gpNodeStr) println("------------------------------------------")*/ - val opOut = OperatorTranslation.gpromStructureToMimirOperator(0, gpromNode, null) + val opOut = db.gpromTranslator.gpromStructureToMimirOperator(0, gpromNode, null) /*println("--------------mimir Op--------------------") println(opOut.toString()) println("------------mimir Op Json-----------------") diff --git a/src/test/scala/mimir/demo/SimpleDemoScript.scala b/src/test/scala/mimir/demo/SimpleDemoScript.scala index f0563a28..abecd8e2 100644 --- a/src/test/scala/mimir/demo/SimpleDemoScript.scala +++ b/src/test/scala/mimir/demo/SimpleDemoScript.scala @@ -95,7 +95,7 @@ object SimpleDemoScript query("SELECT * FROM RATINGS1 WHERE RATING > 4;"){ _.toSeq must have size(2) } query("SELECT * FROM RATINGS2;"){ _.toSeq must have size(3) } db.typechecker.schemaOf(select("SELECT * FROM RATINGS2;")). 
- map(_._2).map(Type.rootType _) must be equalTo List(TString(), TFloat(), TFloat()) + map(_._2).map(db.types.rootType _) must be equalTo List(TString(), TFloat(), TFloat()) } "Create and Query Domain Constraint Repair Lenses" >> { diff --git a/src/test/scala/mimir/lenses/TypeInferenceSpec.scala b/src/test/scala/mimir/lenses/TypeInferenceSpec.scala index 98552bec..501d4236 100644 --- a/src/test/scala/mimir/lenses/TypeInferenceSpec.scala +++ b/src/test/scala/mimir/lenses/TypeInferenceSpec.scala @@ -32,8 +32,8 @@ object TypeInferenceSpec "Detect Timestamps Correctly" >> { - Type.matches(TTimestamp(), "2014-06-15 08:23:19") must beTrue - Type.matches(TTimestamp(), "2013-10-07 08:23:19.120") must beTrue + db.types.testForTypes("2014-06-15 08:23:19") must contain(TTimestamp()) + db.types.testForTypes("2013-10-07 08:23:19.120") must contain(TTimestamp()) db.loadTable("DETECTSERIESTEST1", new File("test/data/DetectSeriesTest1.csv")) diff --git a/src/test/scala/mimir/models/EditDistanceMatchModelSpec.scala b/src/test/scala/mimir/models/EditDistanceMatchModelSpec.scala index cc1ed01f..cbe494ca 100644 --- a/src/test/scala/mimir/models/EditDistanceMatchModelSpec.scala +++ b/src/test/scala/mimir/models/EditDistanceMatchModelSpec.scala @@ -7,7 +7,7 @@ import mimir.algebra._ object EditDistanceMatchModelSpec extends Specification { - def train(src: Map[String,Type], tgt: Map[String,Type]): Map[String,(Model,Int)] = + def train(src: Map[String,BaseType], tgt: Map[String,BaseType]): Map[String,(Model,Int)] = { EditDistanceMatchModel.train(null, "TEMP", Right(src.toList), Right(tgt.toList) diff --git a/src/test/scala/mimir/models/TypeInferenceModelSpec.scala b/src/test/scala/mimir/models/TypeInferenceModelSpec.scala index 96e1b43d..82582696 100644 --- a/src/test/scala/mimir/models/TypeInferenceModelSpec.scala +++ b/src/test/scala/mimir/models/TypeInferenceModelSpec.scala @@ -14,7 +14,14 @@ object TypeInferenceModelSpec extends SQLTestSpecification("TypeInferenceTests") def train(elems: List[String]): TypeInferenceModel = { - val model = new TypeInferenceModel("TEST_MODEL", Array("TEST_COLUMN"), 0.5, db.backend.asInstanceOf[BackendWithSparkContext].getSparkContext(),None) + val model = new TypeInferenceModel( + "TEST_MODEL", + Array("TEST_COLUMN"), + 0.5, + db.backend.asInstanceOf[BackendWithSparkContext].getSparkContext(), + None, + db.types.getSerializable + ) elems.foreach( model.learn(0, _) ) return model } @@ -53,7 +60,14 @@ object TypeInferenceModelSpec extends SQLTestSpecification("TypeInferenceTests") LoggerUtils.debug( "mimir.models.TypeInferenceModel" ){ - val model = new TypeInferenceModel("CPUSPEED:CORES", Array("CORES"), 0.5, db.backend.asInstanceOf[BackendWithSparkContext].getSparkContext(), Some(db.backend.execute(table("CPUSPEED")))) + val model = new TypeInferenceModel( + "CPUSPEED:CORES", + Array("CORES"), + 0.5, + db.backend.asInstanceOf[BackendWithSparkContext].getSparkContext(), + Some(db.backend.execute(table("CPUSPEED"))), + db.types.getSerializable + ) //model.train(db.backend.execute(table("CPUSPEED"))) guess(model) must be equalTo(TInt()) } diff --git a/src/test/scala/mimir/sql/SqlParserSpec.scala b/src/test/scala/mimir/sql/SqlParserSpec.scala index c0ebb5b7..48b61f58 100644 --- a/src/test/scala/mimir/sql/SqlParserSpec.scala +++ b/src/test/scala/mimir/sql/SqlParserSpec.scala @@ -60,10 +60,10 @@ object SqlParserSpec val sback = new SparkBackend(if(tempDB == null){ "testdb" } else { tempDB.toString.split("[\\\\/]").last.replaceAll("\\..*", "") }) val d = new Database(sback, j) try { 
+ sback.sparkTranslator = d.sparkTranslator d.metadataBackend.open() d.backend.open(); SparkML(sback.sparkSql) - OperatorTranslation.db = d j.enableInlining(d) d.initializeDBForMimir(); } catch { @@ -73,7 +73,7 @@ object SqlParserSpec testData.foreach ( _ match { case ( tableName, tableData, tableCols ) => d.backend.dropTable(tableName) LoadCSV.handleLoadTableRaw(d, tableName, - Some(tableCols.map(el => (el._1, Type.fromString(el._2)))), tableData, Map()) + Some(tableCols.map(el => (el._1, BaseType.fromString(el._2).get))), tableData, Map()) }) d } catch { @@ -121,14 +121,14 @@ object SqlParserSpec cast1("string") must be equalTo TString() cast1("date") must be equalTo TDate() cast1("timestamp") must be equalTo TTimestamp() - cast1("flibble") must throwA[RAException] + cast1("flibble") must throwA[SQLException] cast2("int") must be equalTo TInt() cast2("double") must be equalTo TFloat() cast2("string") must be equalTo TString() cast2("date") must be equalTo TDate() cast2("timestamp") must be equalTo TTimestamp() - cast2("flibble") must throwA[RAException] + cast2("flibble") must throwA[SQLException] } "Parse trivial aggregate queries" in { diff --git a/src/test/scala/mimir/test/GProMSQLTestSpecification.scala b/src/test/scala/mimir/test/GProMSQLTestSpecification.scala index f777d7c9..4751b8ab 100644 --- a/src/test/scala/mimir/test/GProMSQLTestSpecification.scala +++ b/src/test/scala/mimir/test/GProMSQLTestSpecification.scala @@ -35,6 +35,7 @@ object GProMDBTestInstances val oldDBExists = dbFile.exists(); val backend = new GProMBackend(jdbcBackendMode, tempDBName+".db", 1) val tmpDB = new Database(backend, new JDBCMetadataBackend(jdbcBackendMode, tempDBName+".db")); + backend.sparkBackend.sparkTranslator = tmpDB.sparkTranslator if(shouldResetDB){ if(dbFile.exists()){ dbFile.delete(); } } @@ -57,8 +58,6 @@ object GProMDBTestInstances ExperimentalOptions.enable("GPROM-BACKEND") ExperimentalOptions.enable("GPROM-PROVENANCE") ExperimentalOptions.enable("GPROM-DETERMINISM") - mimir.algebra.gprom.OperatorTranslation.db = tmpDB - mimir.algebra.spark.OperatorTranslation.db = tmpDB databases.put(tempDBName, (tmpDB, backend)) (tmpDB, backend) } diff --git a/src/test/scala/mimir/test/PDBench.scala b/src/test/scala/mimir/test/PDBench.scala index 0f29e377..d66be06f 100644 --- a/src/test/scala/mimir/test/PDBench.scala +++ b/src/test/scala/mimir/test/PDBench.scala @@ -7,7 +7,7 @@ object PDBench { def isDownloaded = new File("test/pdbench").exists() - val tables = Map[String, (Seq[String], Seq[(String,String,Type,Double)])]( + val tables = Map[String, (Seq[String], Seq[(String,String,BaseType,Double)])]( "customer" -> ( Seq("custkey"), Seq( diff --git a/src/test/scala/mimir/test/SQLTestSpecification.scala b/src/test/scala/mimir/test/SQLTestSpecification.scala index 9004fee6..c8ecd573 100644 --- a/src/test/scala/mimir/test/SQLTestSpecification.scala +++ b/src/test/scala/mimir/test/SQLTestSpecification.scala @@ -46,13 +46,13 @@ object DBTestInstances val backend = new JDBCMetadataBackend(jdbcBackendMode, tempDBName+".db") val sback = new SparkBackend(tempDBName) val tmpDB = new Database(sback, backend); + sback.sparkTranslator = tmpDB.sparkTranslator if(shouldCleanupDB){ dbFile.deleteOnExit(); } tmpDB.metadataBackend.open() tmpDB.backend.open(); SparkML(sback.sparkSql) - OperatorTranslation.db = tmpDB if(shouldResetDB || !oldDBExists){ config.get("initial_db") match { case None => () diff --git a/src/test/scala/mimir/timing/vldb2017/MCDBTimingSpec.scala 
b/src/test/scala/mimir/timing/vldb2017/MCDBTimingSpec.scala index 11fd6432..3018947a 100644 --- a/src/test/scala/mimir/timing/vldb2017/MCDBTimingSpec.scala +++ b/src/test/scala/mimir/timing/vldb2017/MCDBTimingSpec.scala @@ -58,7 +58,7 @@ object MCDBTimingSpec } - def createTable(tableInfo:(String, String, Seq[(String, Type)], Double), tableSuffix: String = "_cleaned") = + def createTable(tableInfo:(String, String, Seq[(String, BaseType)], Double), tableSuffix: String = "_cleaned") = { val (baseTable, ddl, schema, timeout) = tableInfo diff --git a/src/test/scala/mimir/timing/vldb2017/VLDB2017TimingTest.scala b/src/test/scala/mimir/timing/vldb2017/VLDB2017TimingTest.scala index 03aa74fc..6f9851cd 100644 --- a/src/test/scala/mimir/timing/vldb2017/VLDB2017TimingTest.scala +++ b/src/test/scala/mimir/timing/vldb2017/VLDB2017TimingTest.scala @@ -29,7 +29,7 @@ abstract class VLDB2017TimingTest(dbName: String, config: Map[String,String]) val sampler = new SampleRows( (0 until 10).map { _ => random.nextLong }) - def loadTable(tableFields:(String, String, Type, Double), run:Int=1) = + def loadTable(tableFields:(String, String, BaseType, Double), run:Int=1) = { println(s"VLDB2017TimingTest.loadTable(${tableFields})") val (baseTable, columnName, columnType, timeout) = tableFields
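
The hunks above replace the old global hook on the translator object with a translator instance owned by the Database, which each backend must then be pointed at by hand. A minimal sketch of the setup order this patch establishes at every construction site; the "sqlite" mode string, the database name, and the import path of JDBCMetadataBackend are illustrative assumptions, not names confirmed by the patch:

import mimir.Database
import mimir.sql.{JDBCMetadataBackend, SparkBackend}

object TranslatorWiringSketch {
  def main(args: Array[String]): Unit = {
    // The backend is constructed first, then the Database that wraps it.
    val sback = new SparkBackend("testdb")                                   // assumed db name
    val db = new Database(sback, new JDBCMetadataBackend("sqlite", "testdb.db")) // assumed mode
    // The backend cannot see the per-Database translator until the Database
    // exists, so the patch wires it explicitly, always before open().
    sback.sparkTranslator = db.sparkTranslator
    db.metadataBackend.open()
    db.backend.open()
  }
}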
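
Since BaseType.fromString now returns Option[BaseType] instead of throwing, each call site in this patch chooses its own failure policy. A sketch of the two policies used above; parseCastType and parseTypeName are hypothetical helper names introduced only for illustration:

import java.sql.SQLException
import mimir.algebra._

object FromStringSketch {
  // SqlToRA's policy: an unrecognized type name inside a CAST is a SQL error.
  def parseCastType(raw: String): BaseType =
    BaseType.fromString(raw).getOrElse {
      throw new SQLException(s"Invalid CAST Type: $raw")
    }

  // The JDBCUtils / SparkUtils / TextUtils policy: an unrecognized name read
  // back from a TType() column is treated as a user-defined type.
  def parseTypeName(raw: String): Type =
    BaseType.fromString(raw).getOrElse { TUser(raw) }
}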
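
The SpecializeForSQLite and MimirCast hunks split a CAST across two runtimes: the rewriter encodes the root type as an integer id inside MIMIRCAST(col, id), and the SQLite UDF decodes that id back into a BaseType. A sketch of the invariant this relies on, assuming db.types.idForType and BaseType.idTypeOrder are inverses as the patch suggests:

import mimir.algebra._

object CastIdSketch {
  // Hypothetical consistency check; db is an open mimir.Database.
  def roundTrips(db: mimir.Database, t: Type): Boolean = {
    val root = db.types.rootType(t)      // user-defined types collapse to a root type
    val id = db.types.idForType(root)    // encoded into MIMIRCAST(col, id) by the rewriter
    BaseType.idTypeOrder(id) == root     // decoded again inside the SQLite UDF
  }
}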
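
SealedSubclassEnumeration.values is a def macro: at compile time it verifies that the requested type is a sealed trait or class whose direct subclasses are all objects (aborting compilation otherwise) and expands to a Set of those objects. A usage sketch with a hypothetical sealed trait, not a type from this patch:

import mimir.util.SealedSubclassEnumeration

sealed trait Suit
case object Hearts extends Suit
case object Spades extends Suit
case object Clubs extends Suit
case object Diamonds extends Suit

object Suits {
  // Expands at compile time to Set(Hearts, Spades, Clubs, Diamonds).
  val all: Set[Suit] = SealedSubclassEnumeration.values[Suit]
}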