#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
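// Returns true if the value is defined by a zero constant (integer or float).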
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
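// Shared precondition for the vector.load/vector.store lowerings and the
// block transfer lowerings: only 1-D and 2-D vectors are supported.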
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}
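// Preconditions for lowering vector.transfer_read/transfer_write: masked
// transfers, buffers that are not contiguous in the innermost dimension,
// out-of-bounds accesses on 1-D vectors, and permutation maps touching
// non-innermost dimensions are all rejected.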
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // Dynamic shapes and strides have to be provided explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute the strides as running products, innermost stride is unit.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
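// Reorders the stride values according to the transfer op's permutation map.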
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  // ...
}
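// Returns the SSA strides and base offset of the accessed memref. Statically
// known values become arith.constant ops; otherwise they are recovered at
// runtime through memref.extract_strided_metadata. For transfer ops the
// strides are additionally permuted to match the op's permutation map.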
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  SmallVector<Value> strides;
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });
    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // Fall back to extracting strides and offset at runtime.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType()));
    resultTypes.push_back(indexType); // offset
    for (unsigned i = 0; i < rank; ++i) // sizes
      resultTypes.push_back(indexType);
    for (unsigned i = 0; i < rank; ++i) // strides
      resultTypes.push_back(indexType);

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
    strides.clear();
    strides.append(meta.getStrides().begin(), meta.getStrides().end());
    offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
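// Computes per-element linearized offsets for a scattered transfer lowering:
// a step vector per dimension is scaled by the matching stride, broadcast to
// the full vector shape, summed, and finally combined with the scalar base
// offset derived from the transfer indices.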
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // One step vector (0, 1, ..., dim-1) per vector dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Scale each step vector by the stride of the matching memref dimension.
  size_t vectorRank = vectorShape.size();
  size_t memrefRank = strides.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each scaled step vector so it broadcasts against the others.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast to the full vector shape and sum the per-dimension offsets.
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  SmallVector<Value> broadcasted;
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Fold the transfer indices into the scalar base offset.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}
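// Computes linearized offsets for vector.gather/vector.scatter: the scalar
// offsets are folded into the base offset, and the index vector is scaled by
// the innermost stride before being added.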
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  auto offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());
  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
  Value baseVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
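// Extracts the aligned base pointer of the accessed memref as an i64 value.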
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
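// Lowers vector.transfer_read to a scattered xegpu::LoadGatherOp addressed by
// explicit per-element offsets with an all-true mask.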
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expects memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
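// Lowers vector.transfer_write to a scattered xegpu::StoreScatterOp, mirroring
// the scattered load path above.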
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expects memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}
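// Pattern: vector.transfer_read -> xegpu.create_nd_tdesc + xegpu.load_nd on
// PVC/BMG targets, or a scattered load otherwise. Schematically (a sketch,
// the exact attributes depend on the target and boundary checks):
//
//   %d = xegpu.create_nd_tdesc %src[%i, %j] : memref<8x16xf32>
//          -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %d : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>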
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // Targets without 2D block load support fall back to the scattered path.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, 1,
        isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                          nullptr, transposeAttr,
                                          hint, hint, hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
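// Pattern: vector.transfer_write -> xegpu.create_nd_tdesc + xegpu.store_nd on
// PVC/BMG targets, or a scattered store otherwise.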
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // Targets without 2D block store support fall back to the scattered path.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        1, writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                             hint, hint, hint);
    rewriter.eraseOp(writeOp);

    return success();
  }
};
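// Pattern: vector.gather -> xegpu::LoadGatherOp on a flattened i64 base
// pointer, followed by arith.select to apply the pass-through values.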
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
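// Pattern: vector.scatter -> xegpu::StoreScatterOp on a flattened i64 base
// pointer with the original mask.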
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
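// Pattern: vector.load -> xegpu.create_nd_tdesc + xegpu.load_nd. Boundary
// checks are enabled only for multi-dimensional (block) loads.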
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block (2D) instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), 1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc, nullptr, nullptr,
        hint, hint, hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
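// Pattern: vector.store -> xegpu.create_nd_tdesc + xegpu.store_nd.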
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block (2D) instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                             hint, hint, hint);
    rewriter.eraseOp(storeOp);

    return success();
  }
};
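// Pattern: vector.contract with additive accumulation over 2-D operands
// (row-major matmul-like contractions) -> xegpu.dpas.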
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    Value acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    auto lhs = contractOp.getLhs();
    auto rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        accType, lhs, rhs, acc);
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
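// Pass driver: greedily applies the vector-to-XeGPU conversion patterns to
// the operation.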
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
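// Note: assuming the standard generated pass registration, this conversion is
// exposed to mlir-opt as `-convert-vector-to-xegpu`.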