//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
// Returns true if the given value is a constant zero (integer or float).
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type; vector.load and vector.store already
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}
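
// For instance, a load producing vector<8x16xf32> passes this precondition,
// while a rank-3 vector<2x8x16xf32> access is rejected with the
// "Expects 1D or 2D vector" diagnostic.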
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp,
                                       "Expects projected permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order. The innermost stride is guaranteed
    // to be static and unit by the precondition checks.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
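
// For a dynamically shaped source, the dynamic branch above expands roughly
// as follows (illustrative, unverified IR; the names are made up):
//
//   %d0 = memref.dim %src, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %src, %c1 : memref<?x?xf32>
//   %desc = xegpu.create_nd_tdesc %src[%i, %j], shape: [%d0, %d1],
//       strides: [%d1, 1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>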
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, the tensor descriptor keeps the original
    // (pre-transpose) shape of the source data.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
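
// Sketch of the resulting rewrite for a transposed read (illustrative IR,
// not verified output):
//
//   %v = vector.transfer_read %src[%c0, %c0], %pad
//       {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//       : memref<16x8xf32>, vector<8x16xf32>
//
// becomes roughly:
//
//   %desc = xegpu.create_nd_tdesc %src[%c0, %c0]
//       : memref<16x8xf32> -> !xegpu.tensor_desc<16x8xf32>
//   %v = xegpu.load_nd %desc <{transpose = array<i64: 1, 0>}>
//       : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>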
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp = rewriter.create<xegpu::StoreNdOp>(
        loc, writeOp.getVector(), ndDesc,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
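
// For example (illustrative, unverified IR), an in-bounds 2D write such as
//
//   vector.transfer_write %vec, %src[%i, %j] {in_bounds = [true, true]}
//       : vector<8x16xf32>, memref<8x16xf32>
//
// would turn into an xegpu.create_nd_tdesc followed by an xegpu.store_nd of
// the same 8x16 shape.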
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
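
// Sketch of the rewrite (illustrative IR, not verified output):
//
//   %v = vector.load %src[%i, %j] : memref<8x16xf32>, vector<8x16xf32>
//
// becomes roughly:
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>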
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp = rewriter.create<xegpu::StoreNdOp>(
        loc, vector, ndDesc,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
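
// The store path mirrors the load one: a vector.store of vector<8x16xf32>
// becomes an xegpu.create_nd_tdesc plus an xegpu.store_nd of the same shape
// (illustrative; rank-1 stores simply skip the boundary check).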
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layouts. Currently, there is no way to
    // represent the desired VNNI packing in the contraction.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMaps()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // TODO: Update shape validation to be target aware.
    auto accShape = accType.getShape();
    int64_t dimN = accShape[1];
    if (dimN != 8 && dimN != 16)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid operand dimensions");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
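
// Sketch of the contraction rewrite (illustrative IR, not verified output;
// operand shapes follow the N == 8 or 16 restriction above):
//
//   %d = vector.contract {kind = #vector.kind<add>, ...} %lhs, %rhs, %acc
//       : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
//
// becomes:
//
//   %d = xegpu.dpas %lhs, %rhs, %acc
//       : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//       -> vector<8x16xf32>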
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}
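
// The patterns can also be exercised standalone through the registered pass,
// e.g. (assuming the usual flag derived from the tablegen pass definition):
//
//   mlir-opt --convert-vector-to-xegpu input.mlir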