diff --git a/crates/polyglot-sql/src/lineage.rs b/crates/polyglot-sql/src/lineage.rs index 1f47eeca..59d1136f 100644 --- a/crates/polyglot-sql/src/lineage.rs +++ b/crates/polyglot-sql/src/lineage.rs @@ -1843,6 +1843,120 @@ mod tests { assert_eq!(root_with_schema, Some(DataType::Text)); } + #[test] + fn test_lineage_with_schema_correlated_scalar_subquery() { + let query = + "SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1"; + let dialect = Dialect::get(DialectType::BigQuery); + let expr = dialect + .parse(query) + .unwrap() + .into_iter() + .next() + .expect("expected one expression"); + + let mut schema = MappingSchema::with_dialect(DialectType::BigQuery); + schema + .add_table( + "t1", + &[("id".into(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + schema + .add_table( + "t2", + &[ + ("id".into(), DataType::BigInt { length: None }), + ("val".into(), DataType::BigInt { length: None }), + ], + None, + ) + .expect("schema setup"); + + let node = lineage_with_schema( + "id", + &expr, + Some(&schema), + Some(DialectType::BigQuery), + false, + ) + .expect("lineage_with_schema should handle correlated scalar subqueries"); + + assert_eq!(node.name, "id"); + } + + #[test] + fn test_lineage_with_schema_join_using() { + let query = "SELECT a FROM t1 JOIN t2 USING(a)"; + let dialect = Dialect::get(DialectType::BigQuery); + let expr = dialect + .parse(query) + .unwrap() + .into_iter() + .next() + .expect("expected one expression"); + + let mut schema = MappingSchema::with_dialect(DialectType::BigQuery); + schema + .add_table( + "t1", + &[("a".into(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + schema + .add_table( + "t2", + &[("a".into(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + + let node = lineage_with_schema( + "a", + &expr, + Some(&schema), + Some(DialectType::BigQuery), + false, + ) + .expect("lineage_with_schema should handle JOIN USING"); + 
+        assert_eq!(node.name, "a");
+    }
+
+    #[test]
+    fn test_lineage_with_schema_qualified_table_name() {
+        let query = "SELECT a FROM raw.t1";
+        let dialect = Dialect::get(DialectType::BigQuery);
+        let expr = dialect
+            .parse(query)
+            .unwrap()
+            .into_iter()
+            .next()
+            .expect("expected one expression");
+
+        let mut schema = MappingSchema::with_dialect(DialectType::BigQuery);
+        schema
+            .add_table(
+                "raw.t1",
+                &[("a".into(), DataType::BigInt { length: None })],
+                None,
+            )
+            .expect("schema setup");
+
+        let node = lineage_with_schema(
+            "a",
+            &expr,
+            Some(&schema),
+            Some(DialectType::BigQuery),
+            false,
+        )
+        .expect("lineage_with_schema should handle dotted schema.table names");
+
+        assert_eq!(node.name, "a");
+    }
+
     #[test]
     fn test_lineage_with_schema_none_matches_lineage() {
         let expr = parse("SELECT a FROM t");
diff --git a/crates/polyglot-sql/src/optimizer/qualify_columns.rs b/crates/polyglot-sql/src/optimizer/qualify_columns.rs
index fd053922..babdffc2 100644
--- a/crates/polyglot-sql/src/optimizer/qualify_columns.rs
+++ b/crates/polyglot-sql/src/optimizer/qualify_columns.rs
@@ -222,6 +222,54 @@ pub fn validate_qualify_columns(expression: &Expression) -> QualifyColumnsResult<()> {
     Ok(())
 }
 
+/// Collect USING column names from JOIN clauses and register each with the
+/// resolver, mapping them to the first FROM-clause source that contains them.
+fn register_using_columns(select: &Select, resolver: &mut Resolver) {
+    let using_cols: Vec<String> = select
+        .joins
+        .iter()
+        .flat_map(|j| j.using.iter().map(|id| id.name.clone()))
+        .collect();
+
+    if using_cols.is_empty() {
+        return;
+    }
+
+    // Collect source names from the FROM clause (left side of joins) in order.
+    let from_sources: Vec<String> = select
+        .from
+        .as_ref()
+        .map(|f| {
+            f.expressions
+                .iter()
+                .filter_map(|expr| match expr {
+                    Expression::Table(t) => Some(
+                        t.alias
+                            .as_ref()
+                            .map(|a| a.name.clone())
+                            .unwrap_or_else(|| t.name.name.clone()),
+                    ),
+                    _ => None,
+                })
+                .collect()
+        })
+        .unwrap_or_default();
+
+    for col_name in using_cols {
+        // Find the first FROM-clause source that contains this column.
+        let table = from_sources.iter().find_map(|source| {
+            resolver
+                .get_source_columns(source)
+                .ok()
+                .filter(|cols| cols.contains(&col_name))
+                .map(|_| source.clone())
+        });
+        if let Some(table_name) = table {
+            resolver.add_using_column(col_name, table_name);
+        }
+    }
+}
+
 /// Qualify columns in a scope by adding table qualifiers
 fn qualify_columns_in_scope(
     select: &mut Select,
@@ -229,6 +277,11 @@
     resolver: &mut Resolver,
     allow_partial: bool,
 ) -> QualifyColumnsResult<()> {
+    // Register JOIN USING columns so the resolver can disambiguate them.
+    // USING columns exist in both joined tables; resolve each to a FROM-clause
+    // source (the left side of the join).
+    register_using_columns(select, resolver);
+
     for expr in &mut select.expressions {
         qualify_columns_in_expression(expr, scope, resolver, allow_partial)?;
     }
@@ -468,6 +521,12 @@ fn qualify_single_column(
     if let Some(table) = &col.table {
         let table_name = &table.name;
         if !scope.sources.contains_key(table_name) {
+            // Allow correlated references: if the table exists in the schema
+            // but not in the current scope, it may be referencing an outer scope
+            // (e.g., in a correlated scalar subquery).
+ if resolver.table_exists_in_schema(table_name) { + return Ok(()); + } return Err(QualifyColumnsError::UnknownTable(table_name.clone())); } @@ -2467,6 +2526,143 @@ mod tests { assert!(sql.contains("t.b")); } + #[test] + fn test_qualify_columns_join_using() { + let expr = parse("SELECT a FROM t1 JOIN t2 USING(a)"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "t1", + &[("a".to_string(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + schema + .add_table( + "t2", + &[("a".to_string(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + + // The USING column should be qualified with the left (FROM) table + assert!(sql.contains("t1.a"), "USING column should resolve to FROM table: {sql}"); + } + + #[test] + fn test_qualify_columns_join_using_multiple_columns() { + let expr = parse("SELECT a, b FROM t1 JOIN t2 USING(a, b)"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "t1", + &[ + ("a".to_string(), DataType::BigInt { length: None }), + ("b".to_string(), DataType::BigInt { length: None }), + ], + None, + ) + .expect("schema setup"); + schema + .add_table( + "t2", + &[ + ("a".to_string(), DataType::BigInt { length: None }), + ("b".to_string(), DataType::BigInt { length: None }), + ], + None, + ) + .expect("schema setup"); + + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + + assert!(sql.contains("t1.a"), "USING column 'a' should resolve: {sql}"); + assert!(sql.contains("t1.b"), "USING column 'b' should resolve: {sql}"); + } + + #[test] + fn test_qualify_columns_qualified_table_name() { + let expr = parse("SELECT a FROM raw.t1"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "raw.t1", + &[("a".to_string(), DataType::BigInt { length: None })], + None, + 
) + .expect("schema setup"); + + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + + assert!( + sql.contains("t1.a"), + "column should be qualified with table name: {sql}" + ); + } + + #[test] + fn test_qualify_columns_correlated_scalar_subquery() { + let expr = + parse("SELECT id, (SELECT AVG(val) FROM t2 WHERE t2.id = t1.id) AS avg_val FROM t1"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "t1", + &[("id".to_string(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + schema + .add_table( + "t2", + &[ + ("id".to_string(), DataType::BigInt { length: None }), + ("val".to_string(), DataType::BigInt { length: None }), + ], + None, + ) + .expect("schema setup"); + + let result = + qualify_columns(expr, &schema, &QualifyColumnsOptions::new()).expect("qualify"); + let sql = gen(&result); + + assert!(sql.contains("t1.id"), "outer column should be qualified: {sql}"); + assert!(sql.contains("t2.id"), "inner column should be qualified: {sql}"); + } + + #[test] + fn test_qualify_columns_rejects_unknown_table() { + let expr = parse("SELECT id FROM t1 WHERE nonexistent.col = 1"); + + let mut schema = MappingSchema::new(); + schema + .add_table( + "t1", + &[("id".to_string(), DataType::BigInt { length: None })], + None, + ) + .expect("schema setup"); + + let result = qualify_columns(expr, &schema, &QualifyColumnsOptions::new()); + assert!( + result.is_err(), + "should reject reference to table not in scope or schema" + ); + } + // ====================================================================== // quote_identifiers tests // ====================================================================== diff --git a/crates/polyglot-sql/src/resolver.rs b/crates/polyglot-sql/src/resolver.rs index 9d5db131..1c5599a8 100644 --- a/crates/polyglot-sql/src/resolver.rs +++ b/crates/polyglot-sql/src/resolver.rs @@ -10,7 +10,7 @@ //! 
Based on the Python implementation in `sqlglot/optimizer/resolver.py`.
 
 use crate::dialects::DialectType;
-use crate::expressions::{Expression, Identifier};
+use crate::expressions::{Expression, Identifier, TableRef};
 use crate::schema::Schema;
 use crate::scope::{Scope, SourceInfo};
 use std::collections::{HashMap, HashSet};
@@ -54,6 +54,8 @@ pub struct Resolver<'a> {
     unambiguous_columns_cache: Option<HashMap<String, String>>,
     /// Cached set of all available columns
     all_columns_cache: Option<HashSet<String>>,
+    /// Columns disambiguated by JOIN USING: column_name -> table_name
+    using_column_tables: HashMap<String, String>,
 }
 
 impl<'a> Resolver<'a> {
@@ -67,9 +69,18 @@ impl<'a> Resolver<'a> {
             source_columns_cache: HashMap::new(),
             unambiguous_columns_cache: None,
             all_columns_cache: None,
+            using_column_tables: HashMap::new(),
         }
     }
 
+    /// Register a USING column with its resolved table.
+    /// This allows ambiguous columns from JOIN USING to be resolved.
+    pub fn add_using_column(&mut self, column_name: String, table_name: String) {
+        self.using_column_tables.insert(column_name, table_name);
+        // Invalidate caches since USING changes resolution
+        self.unambiguous_columns_cache = None;
+    }
+
     /// Get the table for a column name.
     ///
    /// Returns the table name if it can be found/inferred.
@@ -82,6 +93,11 @@ impl<'a> Resolver<'a> {
             return table_name;
         }
 
+        // Check if this column was disambiguated by a JOIN USING clause
+        if let Some(table) = self.using_column_tables.get(column_name) {
+            return Some(table.clone());
+        }
+
         // If schema inference is enabled and exactly one source has no schema,
         // assume the column belongs to that source
         if self.infer_schema {
@@ -105,6 +121,12 @@ impl<'a> Resolver<'a> {
         self.get_table(column_name).map(Identifier::new)
     }
 
+    /// Check if a table exists in the schema (not necessarily in the current scope).
+    /// Used to detect correlated references to outer scope tables.
+    pub fn table_exists_in_schema(&self, table_name: &str) -> bool {
+        self.schema.column_names(table_name).is_ok()
+    }
+
     /// Get all available columns across all sources in this scope
     pub fn all_columns(&mut self) -> &HashSet<String> {
         if self.all_columns_cache.is_none() {
@@ -148,8 +170,10 @@ fn extract_columns_from_source(&self, source_info: &SourceInfo) -> ResolverResult<Vec<String>> {
         let columns = match &source_info.expression {
             Expression::Table(table) => {
-                // For tables, try to get columns from schema
-                let table_name = table.name.name.clone();
+                // For tables, try to get columns from schema.
+                // Build the fully qualified name (catalog.schema.table) to
+                // match how MappingSchema stores hierarchical keys.
+                let table_name = qualified_table_name(table);
                 match self.schema.column_names(&table_name) {
                     Ok(cols) => cols,
                     Err(_) => Vec::new(), // Schema might not have this table
                 }
@@ -387,6 +411,19 @@ pub fn is_column_ambiguous(scope: &Scope, schema: &dyn Schema, column_name: &str) -> bool {
     resolver.is_ambiguous(column_name)
 }
 
+/// Build the fully qualified table name (catalog.schema.table) from a TableRef.
+fn qualified_table_name(table: &TableRef) -> String {
+    let mut parts = Vec::new();
+    if let Some(catalog) = &table.catalog {
+        parts.push(catalog.name.clone());
+    }
+    if let Some(schema) = &table.schema {
+        parts.push(schema.name.clone());
+    }
+    parts.push(table.name.name.clone());
+    parts.join(".")
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;