25#include "llvm/ADT/TypeSwitch.h"
31#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
32#include "mlir/Conversion/Passes.h.inc"
40static bool isZeroConstant(
Value val) {
46 .Case([](FloatAttr floatAttr) {
return floatAttr.getValue().isZero(); })
47 .Case([](IntegerAttr intAttr) {
return intAttr.getValue().isZero(); })
56 unsigned vecRank = vecTy.getRank();
57 if (!(vecRank == 1 || vecRank == 2))
60 if (!vecTy.getElementType().isIntOrFloat())
62 op,
"Expected scalar type with known bitwidth");
68 if (!memTy.getElementType().isIntOrFloat())
70 op,
"Unsupported memref element type: expected integer or float");
76 VectorTransferOpInterface xferOp) {
79 "Masked transfer is not supported");
81 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
88 if (
failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
90 xferOp,
"Buffer must be contiguous in the innermost dimension");
92 VectorType vecTy = xferOp.getVectorType();
93 unsigned vecRank = vecTy.getRank();
94 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
96 xferOp,
"Boundary check is available only for block instructions.");
103 auto dim = dyn_cast<AffineDimExpr>(expr);
104 if (dim.getPosition() < (numInputDims - vecRank))
106 xferOp,
"Only the innermost dimensions can be accessed");
112static xegpu::CreateNdDescOp createNdDescriptor(
PatternRewriter &rewriter,
114 xegpu::TensorDescType descType,
116 MemRefType srcTy = src.getType();
117 assert(srcTy.isStrided() &&
"Expected strided memref type");
118 auto [strides, offset] = srcTy.getStridesAndOffset();
119 bool isStatic =
true;
122 if (!srcTy.hasStaticShape())
125 if (!ShapedType::isStatic(offset))
128 for (
auto stride : strides) {
129 if (!ShapedType::isStatic(stride)) {
135 xegpu::CreateNdDescOp ndDesc;
137 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
142 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
143 auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
144 rewriter, loc, meta.getBaseBuffer());
145 auto offset = meta.getOffset();
146 auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
147 auto offsetInBytes = arith::MulIOp::create(
148 rewriter, loc, offset,
150 auto adjustedBaseAddr = arith::AddIOp::create(
151 rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
152 auto adjustedAddrI64 = arith::IndexCastOp::create(
153 rewriter, loc, rewriter.
getI64Type(), adjustedBaseAddr);
154 ndDesc = xegpu::CreateNdDescOp::create(
155 rewriter, loc, descType, adjustedAddrI64,
156 meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
179static void adjustStridesForPermutation(
AffineMap permMap,
194 typename = std::enable_if_t<llvm::is_one_of<
195 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
196 vector::GatherOp, vector::ScatterOp>::value>>
197static std::pair<SmallVector<Value>,
Value>
200 Value baseMemref = xferOp.getBase();
201 MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.
getType());
204 Value offsetVal =
nullptr;
205 if (memrefType.hasStaticShape()) {
208 if (
failed(memrefType.getStridesAndOffset(intStrides, offset)))
209 return {{}, offsetVal};
210 bool hasDynamicStrides = llvm::any_of(intStrides, [](
int64_t strideVal) {
211 return ShapedType::isDynamic(strideVal);
214 if (!hasDynamicStrides)
218 if (!ShapedType::isDynamic(offset))
222 if (strides.empty() || !offsetVal) {
225 unsigned rank = memrefType.getRank();
231 resultTypes.push_back(MemRefType::get(
232 {}, memrefType.getElementType()));
233 resultTypes.push_back(indexType);
235 for (
unsigned i = 0; i < rank; ++i)
236 resultTypes.push_back(indexType);
238 for (
unsigned i = 0; i < rank; ++i)
239 resultTypes.push_back(indexType);
241 auto meta = memref::ExtractStridedMetadataOp::create(
242 rewriter, loc, resultTypes, baseMemref);
245 strides.append(meta.getStrides().begin(), meta.getStrides().end());
248 offsetVal = meta.getOffset();
251 if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
252 vector::TransferWriteOp>::value) {
255 adjustStridesForPermutation(permMap, strides);
258 return {strides, offsetVal};
289static Value computeOffsets(VectorTransferOpInterface xferOp,
293 VectorType vectorType = xferOp.getVectorType();
295 xferOp.getIndices().end());
301 auto stepType = VectorType::get({dim}, rewriter.
getIndexType());
302 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
303 stepVectors.push_back(stepOp);
308 size_t memrefRank = strides.size();
311 for (
size_t i = 0; i < vectorRank; ++i) {
312 size_t memrefDim = memrefRank - vectorRank + i;
313 Value strideValue = strides[memrefDim];
314 auto mulType = dyn_cast<VectorType>(stepVectors[i].
getType());
316 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
317 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
318 strideMultiplied.push_back(mulOp);
323 for (
size_t i = 0; i < vectorRank; ++i) {
326 auto newType = VectorType::get(newShape, rewriter.
getIndexType());
327 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
328 strideMultiplied[i]);
329 shapeCasted.push_back(castOp);
334 auto fullIndexVectorType =
336 for (
Value shapeCastVal : shapeCasted) {
337 auto broadcastOp = vector::BroadcastOp::create(
338 rewriter, loc, fullIndexVectorType, shapeCastVal);
339 broadcasted.push_back(broadcastOp);
343 Value localOffsets = broadcasted[0];
344 for (
size_t i = 1; i < broadcasted.size(); ++i)
346 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
349 for (
size_t i = 0; i <
indices.size(); ++i) {
350 Value strideVal = strides[i];
351 Value offsetContrib =
352 arith::MulIOp::create(rewriter, loc,
indices[i], strideVal);
354 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
357 Value bcastBase = vector::BroadcastOp::create(
358 rewriter, loc, fullIndexVectorType, baseOffset);
359 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
370 typename = std::enable_if_t<llvm::is_one_of<
371 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
376 for (
size_t i = 0; i < offsets.size(); ++i) {
377 Value offsetContrib =
378 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
380 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
383 VectorType vecType = cast<VectorType>(
indices.getType());
386 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
388 Value stridedIndices =
389 arith::MulIOp::create(rewriter, loc, strideVector,
indices).getResult();
392 vector::BroadcastOp::create(
394 VectorType::get(vecType.getShape(), rewriter.
getIndexType()),
397 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
406static std::pair<Value, SmallVector<OpFoldResult>>
411 auto memrefType = cast<MemRefType>(
memref.getType());
412 unsigned rank = memrefType.getRank();
414 if (rank <= targetRank)
417 int64_t numCombinedDims = rank - targetRank;
423 for (
unsigned i = 0; i < numCombinedDims; ++i) {
424 subviewOffsets.push_back(offsets[i]);
431 auto originalShape = memrefType.getShape();
432 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc,
memref);
433 for (
unsigned i = numCombinedDims; i < rank; ++i) {
435 if (ShapedType::isDynamic(originalShape[i])) {
436 subviewSizes.push_back(meta.getSizes()[i]);
437 resultShape.push_back(ShapedType::kDynamic);
440 resultShape.push_back(originalShape[i]);
445 auto resultType = memref::SubViewOp::inferRankReducedResultType(
446 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
448 memref::SubViewOp::create(rewriter, loc, resultType,
memref,
449 subviewOffsets, subviewSizes, subviewStrides);
454 return {subviewOp.getResult(), newOffsets};
459 typename = std::enable_if_t<llvm::is_one_of<
460 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
461 vector::GatherOp, vector::ScatterOp>::value>>
465 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
466 rewriter, loc, xferOp.getBase())
468 return arith::IndexCastOp::create(rewriter, loc, rewriter.
getI64Type(),
473static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
477 VectorType vectorType = readOp.getVectorType();
479 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
483 auto meta = computeMemrefMeta(readOp, rewriter);
484 if (meta.first.empty())
488 computeOffsets(readOp, rewriter, meta.first, meta.second);
490 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
492 Value mask = vector::ConstantMaskOp::create(
495 auto gatherOp = xegpu::LoadGatherOp::create(
496 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
498 xegpu::CachePolicyAttr{},
499 xegpu::CachePolicyAttr{},
500 xegpu::CachePolicyAttr{},
503 rewriter.
replaceOp(readOp, gatherOp.getResult());
507static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
511 VectorType vectorType = writeOp.getVectorType();
514 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
518 auto meta = computeMemrefMeta(writeOp, rewriter);
519 if (meta.first.empty())
523 computeOffsets(writeOp, rewriter, meta.first, meta.second);
525 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
527 Value mask = vector::ConstantMaskOp::create(
530 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
533 xegpu::CachePolicyAttr{},
534 xegpu::CachePolicyAttr{},
535 xegpu::CachePolicyAttr{},
541struct TransferReadLowering :
public OpRewritePattern<vector::TransferReadOp> {
544 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
545 PatternRewriter &rewriter)
const override {
546 Location loc = readOp.getLoc();
548 if (
failed(transferPreconditions(rewriter, readOp)))
553 if (chip !=
"pvc" && chip !=
"bmg") {
557 if (readOp.hasOutOfBoundsDim())
559 return lowerToScatteredLoadOp(readOp, rewriter);
562 VectorType loadedVecTy = readOp.getVectorType();
565 if (loadedVecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
566 return lowerToScatteredLoadOp(readOp, rewriter);
569 auto readMemTy = cast<MemRefType>(readOp.getShapedType());
571 storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
574 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
575 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
577 readOp,
"Unsupported non-zero padded out-of-bounds read");
579 AffineMap readMap = readOp.getPermutationMap();
581 auto elementType = loadedVecTy.getElementType();
583 SmallVector<int64_t> descShape(loadedVecTy.getShape());
584 if (isTransposeLoad) {
590 loadedVecTy = VectorType::get(descShape, elementType);
592 auto descType = xegpu::TensorDescType::get(
593 descShape, elementType, 1,
594 isOutOfBounds, xegpu::MemorySpace::Global);
595 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
597 loadedVecTy.getRank());
599 xegpu::CachePolicyAttr hint =
nullptr;
600 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
603 Operation *loadedOp =
604 xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc,
indices,
609 if (isTransposeLoad) {
612 auto range = llvm::seq<int64_t>(0, readMap.
getResults().size());
613 SmallVector<int64_t> perm(range.begin(), range.end());
615 loadedOp = vector::TransposeOp::create(
616 rewriter, loc, loadedOp->
getResult(0), permApplied);
624struct TransferWriteLowering
628 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
629 PatternRewriter &rewriter)
const override {
630 Location loc = writeOp.getLoc();
632 if (
failed(transferPreconditions(rewriter, writeOp)))
637 if (chip !=
"pvc" && chip !=
"bmg") {
641 if (writeOp.hasOutOfBoundsDim())
643 return lowerToScatteredStoreOp(writeOp, rewriter);
647 VectorType vecTy = writeOp.getVectorType();
648 auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
649 if (
failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
656 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
657 rewriter, loc, writeOp.getBase(),
660 auto descType = xegpu::TensorDescType::get(
661 vecTy.getShape(), vecTy.getElementType(),
662 1, writeOp.hasOutOfBoundsDim(),
663 xegpu::MemorySpace::Global);
665 xegpu::CachePolicyAttr hint =
nullptr;
666 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
669 auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
683 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
684 PatternRewriter &rewriter)
const override {
685 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
689 Location loc = gatherOp.getLoc();
690 VectorType vectorType = gatherOp.getVectorType();
692 auto meta = computeMemrefMeta(gatherOp, rewriter);
693 if (meta.first.empty())
697 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
698 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
700 auto xeGatherOp = xegpu::LoadGatherOp::create(
701 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
703 xegpu::CachePolicyAttr{},
704 xegpu::CachePolicyAttr{},
705 xegpu::CachePolicyAttr{},
709 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
710 xeGatherOp.getResult(), gatherOp.getPassThru());
711 rewriter.
replaceOp(gatherOp, selectOp.getResult());
719 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
720 PatternRewriter &rewriter)
const override {
721 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
725 Location loc = scatterOp.getLoc();
726 auto meta = computeMemrefMeta(scatterOp, rewriter);
727 if (meta.first.empty())
729 "Failed to compute strides");
732 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
733 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
735 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
736 flatMemref, localOffsets, scatterOp.getMask(),
738 xegpu::CachePolicyAttr{},
739 xegpu::CachePolicyAttr{},
740 xegpu::CachePolicyAttr{},
750 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
751 PatternRewriter &rewriter)
const override {
752 Location loc = loadOp.getLoc();
754 VectorType vecTy = loadOp.getResult().getType();
755 MemRefType memTy = loadOp.getBase().getType();
756 if (
failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
760 bool boundaryCheck = vecTy.getRank() > 1;
762 xegpu::CachePolicyAttr hint =
nullptr;
764 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
768 auto descType = xegpu::TensorDescType::get(
769 vecTy.getShape(), vecTy.getElementType(), 1,
770 boundaryCheck, xegpu::MemorySpace::Global);
772 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
775 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
indices,
789 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
790 PatternRewriter &rewriter)
const override {
791 Location loc = storeOp.getLoc();
794 VectorType vecTy = vector.getType();
795 MemRefType memTy = storeOp.getBase().getType();
796 if (
failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
800 bool boundaryCheck = vecTy.getRank() > 1;
802 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
803 rewriter, loc, storeOp.getBase(),
806 auto descType = xegpu::TensorDescType::get(
807 vecTy.getShape(), vecTy.getElementType(),
808 1, boundaryCheck, xegpu::MemorySpace::Global);
811 xegpu::CachePolicyAttr hint =
nullptr;
812 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
816 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
indices,
827struct ContractionLowering :
public OpRewritePattern<vector::ContractionOp> {
830 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
831 PatternRewriter &rewriter)
const override {
832 Location loc = contractOp.getLoc();
834 if (contractOp.getKind() != vector::CombiningKind::ADD)
836 "Expects add combining kind");
839 VectorType accType = dyn_cast<VectorType>(acc.getType());
840 if (!accType || accType.getRank() != 2)
847 if (
lhs.getType().getRank() != 2 ||
rhs.getType().getRank() != 2)
849 "Expects lhs and rhs 2D vectors");
854 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
863struct ConvertVectorToXeGPUPass
864 :
public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
865 void runOnOperation()
override {
870 return signalPassFailure();
879 .
add<TransferReadLowering, TransferWriteLowering, LoadLowering,
880 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
static std::optional< VectorShape > vectorShape(Type type)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity map on the most minor dimensions.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerAttr getI64IntegerAttr(int64_t value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.