Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions meetings/2025-12-04.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
title: 2025-12-04 - HLSL Working Group Minutes
---

* Discussion topics
* Discuss the state of MatrixSingleSubscript expressions and gather ideas on how to support vector swizzle operartions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The introduction of the MatrixSingleSubscriptExpr AST type meant we needed a
custom emitter in EmitMatrixSingleSubscriptExpr which meant we needed a new
MatrixRow Lvalue type because ExtVectorElt had a requirment for constant
indicies. We can still use MakeExtVectorElt but only if the row index is
constant. For now if we don't have them then EmitExtVectorElementExpr fails.
Like so:

  if (Base.isMatrixRow())
    return EmitUnsupportedLValue(E, "Matrix single index swizzle");

That means today the following cases work

export float4 getMatrix(float4x4 M, int index) {
    return M[index];
}

export void setMatrix(out float4x4 M, int index, float4 V) {
    M[index] = V;
}

export void setMatrix2(out float4x4 M, float4 V) {
    M[3].abgr = V;
}

export void setMatrix3(out float4x4 M, float4 V) {
    M[1].rgba = V;
}

That also means this doesn't work

export void setMatrix4(out float4x4 M, int index, float4 V) {
    M[index].abgr = V;
}

export float3 getMatrix2(float4x4 M, int index) {
    return M[index].rgb;
}

This failing case while visually similar to swizzling on an array of vectors is
very different. for an array-of-vectors, the “hard part” (the dynamic index) is
already solved before the swizzle ever sees it. For matrices, you’re trying to
push that dynamic index into the swizzle layer, where it fundamentally doesn’t
fit. Example

float4 v[10];
float2 x = v[i].xy;

The AST just treates the array subscript as a base:

ArraySubscriptExpr( base = v, idx = i ) --> type float4
ExtVectorElementExpr( base = ArraySubscriptExpr, "xy" )

Crucially: the dynamic index i is already baked into the pointer. The LValue
itself doesn’t need to remember i separately. MakeExtVectorElt only needs:

  1. “where is the vector?” → baseLV.getAddress()
  2. “which components (xy)?” → constant mask {0, 1}

That’s why arrays-of-vectors don’t have your problem: the dynamic part (i) is
entirely handled by the array-subscript lvalue; the swizzle only deals with
constant component selection inside a single, already-chosen vector.

With matrices, what you want conceptually is very similar:

float4x4 M;
float2 x = M[row].xy;

But the the difference is Clang’s matrix extension doesn’t model a matrix as
“array of row-vectors” or “array of col-vectors” at the IR level. Its
represented as one singular vector. For example the ir would look like

instruction <16 x float>  // for a 4x4, flattened

Plan(s)

  1. Make matrices physically be “array of vectors”
    • We would essentially have to give up on using the MatrixType.
    • But if we changed matrix lowering for HLSL to something like: [NumRows x <NumCols x T>] then EmitMatrixSingleSubscriptExpr could mirror EmitArraySubscriptExpr:
      • Now RowIdx lives only in the GEP. From then on:
      • M[row] is just a Simple vector lvalue.
      • M[row].xy uses the exact same ExtVectorElt code as arrays-of-vectors.
  2. Keep flattened representation but add custom row abstraction
    • The “swizzle” logic has to special-case “base is matrix-row” and do the right gather/scatter.
    • in practice this will be less work than 1 but still significant as we are implementing a dynamic index swizzle
  3. support dynamic index swizzles for R-values only and change the spec?
  • With R values we can build the vector and then EmitExtVectorElementExpr will just work.
  • We aleady have L-value swizzle support for when row index is constant
From 2b06eed2274f0e07d74fc0811eb5683652f4c365 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <[email protected]>
Date: Mon, 24 Nov 2025 16:57:45 -0500
Subject: [PATCH 1/2] [HLSL][Matrix] Add support for single subscript accessor

fixes #166206
---
 clang/include/clang/AST/ComputeDependence.h   |  2 +
 clang/include/clang/AST/Expr.h                | 67 +++++++++++++++++
 clang/include/clang/AST/RecursiveASTVisitor.h |  1 +
 clang/include/clang/AST/Stmt.h                |  1 +
 clang/include/clang/Basic/StmtNodes.td        |  1 +
 clang/include/clang/Sema/Sema.h               |  3 +
 clang/lib/AST/ComputeDependence.cpp           |  4 +
 clang/lib/AST/Expr.cpp                        |  1 +
 clang/lib/AST/ExprClassification.cpp          |  3 +
 clang/lib/AST/ExprConstant.cpp                |  1 +
 clang/lib/AST/ItaniumMangle.cpp               |  9 +++
 clang/lib/AST/StmtPrinter.cpp                 |  8 ++
 clang/lib/AST/StmtProfile.cpp                 |  5 ++
 clang/lib/CodeGen/CGExpr.cpp                  | 74 +++++++++++++++++++
 clang/lib/CodeGen/CGExprScalar.cpp            | 35 +++++++++
 clang/lib/CodeGen/CGValue.h                   | 19 ++++-
 clang/lib/CodeGen/CodeGenFunction.h           |  1 +
 clang/lib/Sema/SemaExceptionSpec.cpp          |  1 +
 clang/lib/Sema/SemaExpr.cpp                   | 61 ++++++++++++++-
 clang/lib/Sema/TreeTransform.h                | 29 ++++++++
 clang/lib/Serialization/ASTReaderStmt.cpp     |  8 ++
 clang/lib/Serialization/ASTWriterStmt.cpp     |  9 +++
 clang/lib/StaticAnalyzer/Core/ExprEngine.cpp  |  5 ++
 clang/tools/libclang/CXCursor.cpp             |  5 ++
 24 files changed, 351 insertions(+), 2 deletions(-)

diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..895105640b931 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -28,6 +28,7 @@ class ParenExpr;
 class UnaryOperator;
 class UnaryExprOrTypeTraitExpr;
 class ArraySubscriptExpr;
+class MatrixSingleSubscriptExpr;
 class MatrixSubscriptExpr;
 class CompoundLiteralExpr;
 class ImplicitCastExpr;
@@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
 ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
 ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
 ExprDependence computeDependence(ArraySubscriptExpr *E);
+ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
 ExprDependence computeDependence(MatrixSubscriptExpr *E);
 ExprDependence computeDependence(CompoundLiteralExpr *E);
 ExprDependence computeDependence(ImplicitCastExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..16d9bbe8ff7c1 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr {
   }
 };
 
+/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
+/// MatrixType extension when you want to get\set a vector from a Matrix.
+class MatrixSingleSubscriptExpr : public Expr {
+  enum { BASE, ROW_IDX, END_EXPR };
+  Stmt *SubExprs[END_EXPR];
+
+public:
+  /// matrix[row]
+  ///
+  /// \param Base        The matrix expression.
+  /// \param RowIdx      The row index expression.
+  /// \param T           The type of the row (usually a vector type).
+  /// \param RBracketLoc Location of the closing ']'.
+  MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
+                            SourceLocation RBracketLoc)
+      : Expr(MatrixSingleSubscriptExprClass, T,
+             Base->getValueKind(), // lvalue/rvalue follows the matrix base
+             OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you
+                                   // prefer
+    SubExprs[BASE] = Base;
+    SubExprs[ROW_IDX] = RowIdx;
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
+    setDependence(computeDependence(this));
+  }
+
+  /// Create an empty matrix single-subscript expression.
+  explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
+      : Expr(MatrixSingleSubscriptExprClass, Shell) {}
+
+  Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
+  const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
+  void setBase(Expr *E) { SubExprs[BASE] = E; }
+
+  Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
+  const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
+  void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }
+
+  SourceLocation getBeginLoc() const LLVM_READONLY {
+    return getBase()->getBeginLoc();
+  }
+
+  SourceLocation getEndLoc() const { return getRBracketLoc(); }
+
+  SourceLocation getExprLoc() const LLVM_READONLY {
+    return getBase()->getExprLoc();
+  }
+
+  SourceLocation getRBracketLoc() const {
+    return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
+  }
+  void setRBracketLoc(SourceLocation L) {
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
+  }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == MatrixSingleSubscriptExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+  const_child_range children() const {
+    return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+};
+
 /// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
 /// extension.
 /// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..92409b72e4f0c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2893,6 +2893,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
 // over the children.
 DEF_TRAVERSE_STMT(AddrLabelExpr, {})
 DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
+DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
 DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
 DEF_TRAVERSE_STMT(ArraySectionExpr, {})
 DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index e1cca34d2212c..21d0a7dfe577c 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -530,6 +530,7 @@ class alignas(void *) Stmt {
   class ArrayOrMatrixSubscriptExprBitfields {
     friend class ArraySubscriptExpr;
     friend class MatrixSubscriptExpr;
+    friend class MatrixSingleSubscriptExpr;
 
     LLVM_PREFERRED_TYPE(ExprBitfields)
     unsigned : NumExprBits;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..ada74807e56e2 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -74,6 +74,7 @@ def UnaryOperator : StmtNode<Expr>;
 def OffsetOfExpr : StmtNode<Expr>;
 def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
 def ArraySubscriptExpr : StmtNode<Expr>;
+def MatrixSingleSubscriptExpr : StmtNode<Expr>;
 def MatrixSubscriptExpr : StmtNode<Expr>;
 def ArraySectionExpr : StmtNode<Expr>;
 def OMPIteratorExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..d4d5c3d8bed17 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7406,6 +7406,9 @@ class Sema final : public SemaBase {
   ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
                                              Expr *Idx, SourceLocation RLoc);
 
+  ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
+                                                    SourceLocation RBLoc);
+
   ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
                                               Expr *ColumnIdx,
                                               SourceLocation RBLoc);
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..8429f17d26be5 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) {
   return E->getLHS()->getDependence() | E->getRHS()->getDependence();
 }
 
+ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
+  return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
+}
+
 ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
   return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
          (E->getColumnIdx() ? E->getColumnIdx()->getDependence()
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..b400b2a083d9b 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3789,6 +3789,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
 
   case ParenExprClass:
   case ArraySubscriptExprClass:
+  case MatrixSingleSubscriptExprClass:
   case MatrixSubscriptExprClass:
   case ArraySectionExprClass:
   case OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..9995d1b411c5b 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
     }
     return Cl::CL_LValue;
 
+  case Expr::MatrixSingleSubscriptExprClass:
+    return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase());
+
   // Subscripting matrix types behaves like member accesses.
   case Expr::MatrixSubscriptExprClass:
     return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 11c5e1c6e90f4..52481dc71b75d 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -20667,6 +20667,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
   case Expr::ImaginaryLiteralClass:
   case Expr::StringLiteralClass:
   case Expr::ArraySubscriptExprClass:
+  case Expr::MatrixSingleSubscriptExprClass:
   case Expr::MatrixSubscriptExprClass:
   case Expr::ArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 5572e0a7ae59c..cb71987fba766 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -5482,6 +5482,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
     break;
   }
 
+  case Expr::MatrixSingleSubscriptExprClass: {
+    NotPrimaryExpr();
+    const MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E);
+    Out << "ix";
+    mangleExpression(ME->getBase());
+    mangleExpression(ME->getRowIdx());
+    break;
+  }
+
   case Expr::MatrixSubscriptExprClass: {
     NotPrimaryExpr();
     const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E);
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index ff8ca01ec5477..51b9c47f22ff0 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1685,6 +1685,14 @@ void StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) {
   OS << "]";
 }
 
+void StmtPrinter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *Node) {
+  PrintExpr(Node->getBase());
+  OS << "[";
+  PrintExpr(Node->getRowIdx());
+  OS << "]";
+}
+
 void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) {
   PrintExpr(Node->getBase());
   OS << "[";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 4a8c638c85331..c7b7c65715dfc 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -1508,6 +1508,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const ArraySubscriptExpr *S) {
   VisitExpr(S);
 }
 
+void StmtProfiler::VisitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *S) {
+  VisitExpr(S);
+}
+
 void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) {
   VisitExpr(S);
 }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e842158236cd4..5eda28a297b81 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1796,6 +1796,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
     return EmitUnaryOpLValue(cast<UnaryOperator>(E));
   case Expr::ArraySubscriptExprClass:
     return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
+  case Expr::MatrixSingleSubscriptExprClass:
+    return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E));
   case Expr::MatrixSubscriptExprClass:
     return EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E));
   case Expr::ArraySectionExprClass:
@@ -2440,6 +2442,35 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
         Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
     return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext"));
   }
+  if (LV.isMatrixRow()) {
+    QualType MatTy = LV.getType();
+    const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+    unsigned NumRows = MT->getNumRows();
+    unsigned NumCols = MT->getNumColumns();
+
+    llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc);
+
+    llvm::Value *Row = LV.getMatrixRowIdx();
+    llvm::Value *Result =
+        llvm::UndefValue::get(ConvertType(LV.getType())); // <NumCols x T>
+
+    llvm::MatrixBuilder MB(Builder);
+
+    for (unsigned Col = 0; Col < NumCols; ++Col) {
+      llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+      llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+      llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
+
+      llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+      Result = Builder.CreateInsertElement(Result, Elt, Lane);
+    }
+
+    return RValue::get(Result);
+  }
 
   assert(LV.isBitField() && "Unknown LValue type!");
   return EmitLoadOfBitfieldLValue(LV, Loc);
@@ -2662,6 +2693,36 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
       addInstToCurrentSourceAtom(I, Vec);
       return;
     }
+    if (Dst.isMatrixRow()) {
+      QualType MatTy = Dst.getType();
+      const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+      unsigned NumRows = MT->getNumRows();
+      unsigned NumCols = MT->getNumColumns();
+
+      llvm::Value *MatrixVec =
+          Builder.CreateLoad(Dst.getAddress(), "matrix.load");
+
+      llvm::Value *Row = Dst.getMatrixRowIdx();
+      llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
+
+      llvm::MatrixBuilder MB(Builder);
+
+      for (unsigned Col = 0; Col < NumCols; ++Col) {
+        llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+        llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+        llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+        llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
+
+        MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
+      }
+
+      Builder.CreateStore(MatrixVec, Dst.getAddress());
+      return;
+    }
 
     assert(Dst.isBitField() && "Unknown LValue type");
     return EmitStoreThroughBitfieldLValue(Src, Dst);
@@ -4874,6 +4935,16 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const Expr *E) {
   return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned);
 }
 
+LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *E) {
+  LValue Base = EmitLValue(E->getBase());
+  llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
+
+  return LValue::MakeMatrixRow(
+      MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
+      E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+}
+
 LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
   assert(
       !E->isIncomplete() &&
@@ -5146,6 +5217,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
     return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
                                     Base.getBaseInfo(), TBAAAccessInfo());
   }
+  if (Base.isMatrixRow())
+    return EmitUnsupportedLValue(E, "Matrix single index swizzle");
+
   assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");
 
   llvm::Constant *BaseElts = Base.getExtVectorElts();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 769bc37b0e131..70397e8cb99c2 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -599,6 +599,7 @@ class ScalarExprEmitter
   }
 
   Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E);
+  Value *VisitMatrixSingleSubscriptExpr(MatrixSingleSubscriptExpr *E);
   Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E);
   Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E);
   Value *VisitConvertVectorExpr(ConvertVectorExpr *E);
@@ -2109,6 +2110,40 @@ Value *ScalarExprEmitter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
   return Builder.CreateExtractElement(Base, Idx, "vecext");
 }
 
+Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  TestAndClearIgnoreResultAssign();
+
+  auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+  unsigned NumRows = MatrixTy->getNumRows();
+  unsigned NumColumns = MatrixTy->getNumColumns();
+
+  // Row index
+  Value *RowIdx = CGF.EmitMatrixIndexExpr(E->getRowIdx());
+
+  llvm::MatrixBuilder MB(Builder);
+
+  // The row index must be in [0, NumRows)
+  if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
+    MB.CreateIndexAssumption(RowIdx, NumRows);
+
+  Value *FlatMatrix = Visit(E->getBase());
+  llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType());
+  auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns);
+  Value *RowVec = llvm::UndefValue::get(ResultTy);
+
+  for (unsigned Col = 0; Col != NumColumns; ++Col) {
+    Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
+    Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
+    Value *Elt =
+        Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
+    Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+    RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins");
+  }
+
+  return RowVec;
+}
+
 Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
   TestAndClearIgnoreResultAssign();
 
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index 6b381b59e71cd..c08ca70de10e1 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -187,7 +187,8 @@ class LValue {
     BitField,     // This is a bitfield l-value, use getBitfield*.
     ExtVectorElt, // This is an extended vector subset, use getExtVectorComp
     GlobalReg,    // This is a register l-value, use getGlobalReg()
-    MatrixElt     // This is a matrix element, use getVector*
+    MatrixElt,    // This is a matrix element, use getVector*
+    MatrixRow     // This is a matrix vector subset, use getVector*
   } LVType;
 
   union {
@@ -282,6 +283,7 @@ class LValue {
   bool isExtVectorElt() const { return LVType == ExtVectorElt; }
   bool isGlobalReg() const { return LVType == GlobalReg; }
   bool isMatrixElt() const { return LVType == MatrixElt; }
+  bool isMatrixRow() const { return LVType == MatrixRow; }
 
   bool isVolatileQualified() const { return Quals.hasVolatile(); }
   bool isRestrictQualified() const { return Quals.hasRestrict(); }
@@ -398,6 +400,11 @@ class LValue {
     return VectorIdx;
   }
 
+  llvm::Value *getMatrixRowIdx() const {
+    assert(isMatrixRow());
+    return VectorIdx;
+  }
+
   // extended vector elements.
   Address getExtVectorAddress() const {
     assert(isExtVectorElt());
@@ -486,6 +493,16 @@ class LValue {
     return R;
   }
 
+  static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx,
+                              QualType MatrixTy, LValueBaseInfo BaseInfo,
+                              TBAAAccessInfo TBAAInfo) {
+    LValue LV;
+    LV.LVType = MatrixRow;
+    LV.VectorIdx = RowIdx; // store the row index here
+    LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, TBAAInfo);
+    return LV;
+  }
+
   static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx,
                               QualType type, LValueBaseInfo BaseInfo,
                               TBAAAccessInfo TBAAInfo) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 8c4c1c8c2dc95..3abe516debcb0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4412,6 +4412,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
                                 bool Accessed = false);
   llvm::Value *EmitMatrixIndexExpr(const Expr *E);
+  LValue EmitMatrixSingleSubscriptExpr(const MatrixSingleSubscriptExpr *E);
   LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E);
   LValue EmitArraySectionExpr(const ArraySectionExpr *E,
                               bool IsLowerBound = true);
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index a0483c3027199..8b7f2e0a755d7 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1303,6 +1303,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
     // Some might be dependent for other reasons.
   case Expr::ArraySubscriptExprClass:
   case Expr::MatrixSubscriptExprClass:
+  case Expr::MatrixSingleSubscriptExprClass:
   case Expr::ArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
   case Expr::OMPIteratorExprClass:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index cfabd1b76c103..fe82b8979baa7 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -5090,6 +5090,62 @@ ExprResult Sema::tryConvertExprToType(Expr *E, QualType Ty) {
   return InitSeq.Perform(*this, Entity, Kind, E);
 }
 
+ExprResult Sema::CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base,
+                                                        Expr *RowIdx,
+                                                        SourceLocation RBLoc) {
+  ExprResult BaseR = CheckPlaceholderExpr(Base);
+  if (BaseR.isInvalid())
+    return BaseR;
+  Base = BaseR.get();
+
+  ExprResult RowR = CheckPlaceholderExpr(RowIdx);
+  if (RowR.isInvalid())
+    return RowR;
+  RowIdx = RowR.get();
+
+  // Build an unanalyzed expression if any of the operands is type-dependent.
+  if (Base->isTypeDependent() || RowIdx->isTypeDependent())
+    return new (Context)
+        MatrixSingleSubscriptExpr(Base, RowIdx, Context.DependentTy, RBLoc);
+
+  // Check that IndexExpr is an integer expression. If it is a constant
+  // expression, check that it is less than Dim (= the number of elements in the
+  // corresponding dimension).
+  auto IsIndexValid = [&](Expr *IndexExpr, unsigned Dim,
+                          bool IsColumnIdx) -> Expr * {
+    if (!IndexExpr->getType()->isIntegerType() &&
+        !IndexExpr->isTypeDependent()) {
+      Diag(IndexExpr->getBeginLoc(), diag::err_matrix_index_not_integer)
+          << IsColumnIdx;
+      return nullptr;
+    }
+
+    if (std::optional<llvm::APSInt> Idx =
+            IndexExpr->getIntegerConstantExpr(Context)) {
+      if ((*Idx < 0 || *Idx >= Dim)) {
+        Diag(IndexExpr->getBeginLoc(), diag::err_matrix_index_outside_range)
+            << IsColumnIdx << Dim;
+        return nullptr;
+      }
+    }
+
+    ExprResult ConvExpr = IndexExpr;
+    assert(!ConvExpr.isInvalid() &&
+           "should be able to convert any integer type to size type");
+    return ConvExpr.get();
+  };
+
+  auto *MTy = Base->getType()->getAs<ConstantMatrixType>();
+  RowIdx = IsIndexValid(RowIdx, MTy->getNumRows(), false);
+  if (!RowIdx)
+    return ExprError();
+
+  QualType RowVecQT =
+      Context.getExtVectorType(MTy->getElementType(), MTy->getNumColumns());
+
+  return new (Context) MatrixSingleSubscriptExpr(Base, RowIdx, RowVecQT, RBLoc);
+}
+
 ExprResult Sema::CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
                                                   Expr *ColumnIdx,
                                                   SourceLocation RBLoc) {
@@ -5103,9 +5159,12 @@ ExprResult Sema::CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
     return RowR;
   RowIdx = RowR.get();
 
-  if (!ColumnIdx)
+  if (!ColumnIdx) {
+    if (getLangOpts().HLSL)
+      return CreateBuiltinMatrixSingleSubscriptExpr(Base, RowIdx, RBLoc);
     return new (Context) MatrixSubscriptExpr(
         Base, RowIdx, ColumnIdx, Context.IncompleteMatrixIdxTy, RBLoc);
+  }
 
   // Build an unanalyzed expression if any of the operands is type-dependent.
   if (Base->isTypeDependent() || RowIdx->isTypeDependent() ||
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 8e5dbeb792348..6b8b2317d8aa4 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2849,6 +2849,16 @@ class TreeTransform {
                                              RBracketLoc);
   }
 
+  /// Build a new matrix single subscript expression.
+  ///
+  /// By default, performs semantic analysis to build the new expression.
+  /// Subclasses may override this routine to provide different behavior.
+  ExprResult RebuildMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
+                                              SourceLocation RBracketLoc) {
+    return getSema().CreateBuiltinMatrixSingleSubscriptExpr(Base, RowIdx,
+                                                            RBracketLoc);
+  }
+
   /// Build a new matrix subscript expression.
   ///
   /// By default, performs semantic analysis to build the new expression.
@@ -13378,6 +13388,25 @@ TreeTransform<Derived>::TransformArraySubscriptExpr(ArraySubscriptExpr *E) {
       /*FIXME:*/ E->getLHS()->getBeginLoc(), RHS.get(), E->getRBracketLoc());
 }
 
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  ExprResult Base = getDerived().TransformExpr(E->getBase());
+  if (Base.isInvalid())
+    return ExprError();
+
+  ExprResult RowIdx = getDerived().TransformExpr(E->getRowIdx());
+  if (RowIdx.isInvalid())
+    return ExprError();
+
+  if (!getDerived().AlwaysRebuild() && Base.get() == E->getBase() &&
+      RowIdx.get() == E->getRowIdx())
+    return E;
+
+  return getDerived().RebuildMatrixSingleSubscriptExpr(Base.get(), RowIdx.get(),
+                                                       E->getRBracketLoc());
+}
+
 template <typename Derived>
 ExprResult
 TreeTransform<Derived>::TransformMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index eef97a8588f0b..64fe19b2c660a 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -962,6 +962,14 @@ void ASTStmtReader::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
   E->setRBracketLoc(readSourceLocation());
 }
 
+void ASTStmtReader::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  VisitExpr(E);
+  E->setBase(Record.readSubExpr());
+  E->setRowIdx(Record.readSubExpr());
+  E->setRBracketLoc(readSourceLocation());
+}
+
 void ASTStmtReader::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
   VisitExpr(E);
   E->setBase(Record.readSubExpr());
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index acf345392aa1a..0f6c81f6a4693 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -900,6 +900,15 @@ void ASTStmtWriter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
   Code = serialization::EXPR_ARRAY_SUBSCRIPT;
 }
 
+void ASTStmtWriter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  VisitExpr(E);
+  Record.AddStmt(E->getBase());
+  Record.AddStmt(E->getRowIdx());
+  Record.AddSourceLocation(E->getRBracketLoc());
+  Code = serialization::EXPR_ARRAY_SUBSCRIPT;
+}
+
 void ASTStmtWriter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
   VisitExpr(E);
   Record.AddStmt(E->getBase());
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index a759aee47b8ea..7e99418997afd 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -2082,6 +2082,11 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
       Bldr.addNodes(Dst);
       break;
 
+    case Stmt::MatrixSingleSubscriptExprClass:
+      llvm_unreachable(
+          "Support for MatrixSingleSubscriptExprClass is not implemented.");
+      break;
+
     case Stmt::MatrixSubscriptExprClass:
       llvm_unreachable("Support for MatrixSubscriptExpr is not implemented.");
       break;
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 0a43d73063c1f..b865dd76011bc 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -424,6 +424,11 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::ArraySubscriptExprClass:
     K = CXCursor_ArraySubscriptExpr;
     break;
+  
+  case Stmt::MatrixSingleSubscriptExprClass:
+    // TODO: add support for MatrixSingleSubscriptExpr.
+    K = CXCursor_UnexposedExpr;
+    break;
 
   case Stmt::MatrixSubscriptExprClass:
     // TODO: add support for MatrixSubscriptExpr.

From 1b7ad065781fcf772a58ffaad72a9d892d73a856 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <[email protected]>
Date: Thu, 4 Dec 2025 00:53:59 -0500
Subject: [PATCH 2/2] add swizzle support if row index is constant

---
 clang/lib/CodeGen/CGExpr.cpp      | 19 +++++++++++++++++++
 clang/tools/libclang/CXCursor.cpp |  2 +-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 5eda28a297b81..ca06b5df94cb3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4940,6 +4940,25 @@ LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
   LValue Base = EmitLValue(E->getBase());
   llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
 
+  if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) {
+
+    // Extract matrix shape from the AST type
+    const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+    unsigned NumCols = MatTy->getNumColumns();
+    llvm::SmallVector<llvm::Constant *, 8> Indices;
+    Indices.reserve(NumCols);
+
+    unsigned Row = RowConst->getZExtValue();
+    unsigned Start = Row * NumCols;
+    for (unsigned C = 0; C < NumCols; ++C) {
+      Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C));
+    }
+    llvm::Constant *Elts = llvm::ConstantVector::get(Indices);
+    return LValue::MakeExtVectorElt(
+        MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts,
+        E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+  }
+
   return LValue::MakeMatrixRow(
       MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
       E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index b865dd76011bc..1de1c18a2249f 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -424,7 +424,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::ArraySubscriptExprClass:
     K = CXCursor_ArraySubscriptExpr;
     break;
-  
+
   case Stmt::MatrixSingleSubscriptExprClass:
     // TODO: add support for MatrixSingleSubscriptExpr.
     K = CXCursor_UnexposedExpr;

* <placeholder topic>
* <placeholder topic>
* <placeholder topic>
* <placeholder topic>