#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
// Returns true if the given value is a compile-time constant zero.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
      .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
      .Default([](Attribute) { return false; });
}
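// Note: isZeroConstant recognizes scalar constants such as
// `arith.constant 0.0 : f32` or `arith.constant 0 : i32`; other producers,
// including dense vector constants, conservatively report non-zero.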
// Common preconditions for lowering vector load/store-like ops to XeGPU.
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy,
                                            MemRefType memTy) {
  // Validate only 1D and 2D vectors.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expected 1D or 2D vector");

  if (!vecTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Expected scalar type with known bitwidth");

  if (!memTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Unsupported memref element type: expected integer or float");

  return success();
}
// Preconditions shared by the transfer_read and transfer_write lowerings.
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  // Only the innermost source dimensions may be accessed, possibly in
  // transposed order.
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation())
    return rewriter.notifyMatchFailure(xferOp, "Expects permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
// Creates an XeGPU nd-descriptor for the given strided memref. A fully
// static layout is passed to the descriptor directly; otherwise the base
// address, sizes, and strides are materialized explicitly.
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  if (!srcTy.hasStaticShape())
    isStatic = false;

  if (!ShapedType::isStatic(offset))
    isStatic = false;

  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // Compute the base address explicitly: aligned base pointer plus the
    // memref offset scaled by the element size in bytes.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
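// Illustrative result (IR sketch, not verified builder output): a fully
// static memref<8x16xf32> source yields
//   %0 = xegpu.create_nd_tdesc %src : memref<8x16xf32>
//          -> !xegpu.tensor_desc<8x16xf32>
// while a dynamic layout instead feeds an i64 base address plus the runtime
// sizes and strides into xegpu.create_nd_tdesc.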
// Adjusts memref strides according to a transfer op's permutation map so
// that stride i corresponds to the i-th result dimension of the map.
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}
// Computes the strides and the offset of the memref accessed by the given
// op as index values, handling both static and dynamic layouts. Returns a
// (strides, offset) pair; an empty strides vector signals failure.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });
    if (!hasDynamicStrides)
      for (int64_t stride : intStrides)
        strides.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, stride));
    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // Fall back to extracting the strides and offset at runtime.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();
    SmallVector<Type> resultTypes;
    resultTypes.push_back(
        MemRefType::get({}, memrefType.getElementType())); // base memref
    resultTypes.push_back(indexType);                      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());
    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
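// For a statically shaped memref<8x16xf32> this produces constant strides
// (16, 1) and offset 0 without inspecting the IR; dynamic cases pull the
// values from memref.extract_strided_metadata. Transfer ops additionally get
// their strides permuted to match the permutation map.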
// Computes the per-element offsets of a transfer op as a vector of index
// values: a broadcast scalar base offset plus a grid of local offsets.
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create a step vector for each dimension of the result vector.
  SmallVector<Value> stepVectors;
  for (int64_t dim : vectorShape) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
  }

  // Multiply each step vector by the stride of the corresponding innermost
  // memref dimension.
  size_t vectorRank = vectorShape.size();
  size_t memrefRank = strides.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Reshape each result to a unit shape that keeps its dimension in place.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each reshaped vector to the full result shape.
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  SmallVector<Value> broadcasted;
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add the broadcasted vectors together to get the local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Fold the transfer indices into the scalar base offset.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  // Broadcast the base offset and add the local offsets.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}
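// Worked example (illustrative): for a vector<4x8xf32> read at indices
// (%i, %j) of a row-major memref with strides (16, 1), the result is
//   offsets[a][b] = (%i * 16 + %j * 1) + (a * 16 + b * 1)
// built from vector.step, vector.broadcast, vector.shape_cast, and
// arith.muli/arith.addi on index vectors.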
// Computes the per-lane offsets of a gather/scatter op: the scalar base
// offset derived from the memref offsets plus the index vector scaled by the
// innermost stride.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::GatherOp,
              vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets(gatScatOp.getOffsets().begin(),
                             gatScatOp.getOffsets().end());
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
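// E.g. (illustrative) for a gather from memref<?x?xf32> with strides
// (%s0, 1), base offsets (%i, %j), and index vector %idx:
//   offset[l] = (%i * %s0 + %j * 1) + %idx[l] * 1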
// Collapses the leading dimensions of a memref with a rank-reducing subview
// so the result has the requested target rank; returns the new source
// together with the offsets remaining for the target-rank dimensions.
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // The leading dimensions are reduced to unit size; their offsets select
  // the slice to access.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // The trailing dimensions are taken whole.
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(getAsOpFoldResult(meta.getSizes()[i]));
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}
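// Example (illustrative): collapsing memref<2x4x8x16xf32> at offsets
// (%a, %b, %c, %d) to target rank 2 emits
//   memref.subview %src[%a, %b, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
// and returns the rank-reduced subview together with offsets (%c, %d).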
// Extracts the aligned base pointer of the op's memref as an i64 value.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
// Lowers a transfer_read to a scattered xegpu.load with explicitly computed
// per-lane offsets and an all-true mask.
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();

  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{},
      /*layout=*/nullptr);

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
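// Resulting IR sketch (assumed shapes, syntax abbreviated): a 1-D in-bounds
// read of vector<8xf32> becomes a flat i64 base pointer, a vector<8xindex>
// of offsets, an all-true vector<8xi1> mask, and a single scattered
// xegpu.load producing the vector<8xf32> result.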
// Lowers a transfer_write to a scattered xegpu.store, mirroring the
// scattered load path above.
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{},
                                /*layout=*/nullptr);
  rewriter.eraseOp(writeOp);
  return success();
}
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
    VectorType loadedVecTy = readOp.getVectorType();
    bool isOutOfBounds = readOp.hasOutOfBoundsDim();

    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
    if (isSharedMemory) {
      // Reads from shared local memory (SLM) are lowered through
      // xegpu.load_matrix.
      if (loadedVecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            readOp, "Only 2D vector loads are supported for SLM");

      AffineMap readMap = readOp.getPermutationMap();
      if (!readMap.isMinorIdentity())
        return rewriter.notifyMatchFailure(
            readOp,
            "Non identity transposition is not supported for SLM loads.");

      if (isOutOfBounds)
        return rewriter.notifyMatchFailure(
            readOp, "Out-of-bounds access is not supported for SLM loads");

      auto memDescType = xegpu::MemDescType::get(
          rewriter.getContext(), readMemTy.getShape(),
          readMemTy.getElementType(), /*mem_layout=*/nullptr);
      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, readOp.getBase());

      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(readOp.getIndices());
      auto loadMatrixOp = xegpu::LoadMatrixOp::create(
          rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
          /*layout=*/nullptr);

      rewriter.replaceOp(readOp, loadMatrixOp.getResult());
      return success();
    }

    // Block (nd) loads are only used on chips that support them and only for
    // up to 2D vectors; everything else takes the scattered path.
    std::optional<std::string> chipOpt = xegpu::getChipStr(readOp);
    std::string chip = chipOpt.value_or("");
    if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
        readOp.getVectorType().getRank() > 2) {
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // 1D in-bounds reads are also handled as scattered loads.
    if (loadedVecTy.getRank() == 1 && !isOutOfBounds)
      return lowerToScatteredLoadOp(readOp, rewriter);

    if (failed(
            storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
      return failure();

    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    // Detect whether the innermost two dimensions are read transposed.
    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = false;
    unsigned numInputs = readMap.getNumInputs();
    ArrayRef<AffineExpr> results = readMap.getResults();
    if (numInputs >= 2) {
      AffineExpr lastDim =
          getAffineDimExpr(numInputs - 1, readMap.getContext());
      AffineExpr secondLastDim =
          getAffineDimExpr(numInputs - 2, readMap.getContext());
      isTransposeLoad =
          (results[0] == lastDim && results[1] == secondLastDim);
    }

    auto elementType = loadedVecTy.getElementType();

    SmallVector<int64_t> descShape(loadedVecTy.getShape());
    if (isTransposeLoad) {
      // The descriptor is created for the in-memory (pre-transpose) shape.
      size_t rank = descShape.size();
      assert(rank >= 2 && "Transpose requires at least 2 dimensions");
      std::swap(descShape[rank - 1], descShape[rank - 2]);
      loadedVecTy = VectorType::get(descShape, elementType);
    }

    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(),
        getAsOpFoldResult(readOp.getIndices()), loadedVecTy.getRank());

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    Operation *loadedOp =
        xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint);

    if (isTransposeLoad) {
      // Restore the layout requested by the original read.
      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
      SmallVector<int64_t> perm(range.rbegin(), range.rend());
      loadedOp = vector::TransposeOp::create(rewriter, loc,
                                             loadedOp->getResult(0), perm);
    }

    rewriter.replaceOp(readOp, loadedOp->getResult(0));
    return success();
  }
};
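// Transposed-read example (illustrative): a transfer_read with permutation
// map (d0, d1) -> (d1, d0) producing vector<8x16xf32> builds the descriptor
// for the in-memory 16x8 shape, loads it with xegpu.load_nd, and restores
// the requested order with a trailing vector.transpose.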
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    VectorType vecTy = writeOp.getVectorType();
    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());

    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);
    if (isSharedMemory) {
      // Writes to shared local memory (SLM) are lowered through
      // xegpu.store_matrix.
      if (vecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            writeOp, "Only 2D vector stores are supported for SLM");

      auto memDescType = xegpu::MemDescType::get(
          rewriter.getContext(), writeMemTy.getShape(),
          writeMemTy.getElementType(), /*mem_layout=*/nullptr);
      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, writeOp.getBase());

      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(writeOp.getIndices());
      xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
                                   createMemDescOp.getResult(), indices,
                                   /*layout=*/nullptr);

      rewriter.eraseOp(writeOp);
      return success();
    }

    // Block (nd) stores are only used on chips that support them and only
    // for up to 2D vectors; everything else takes the scattered path.
    std::optional<std::string> chipOpt = xegpu::getChipStr(writeOp);
    std::string chip = chipOpt.value_or("");
    if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
        writeOp.getVectorType().getRank() > 2) {
      // Scattered stores cannot handle out-of-bounds lanes.
      if (writeOp.hasOutOfBoundsDim())
        return rewriter.notifyMatchFailure(
            writeOp, "Unsupported out-of-bounds write");
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
      return failure();

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint,
                                            /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);
    return success();
  }
};
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets,
        gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{},
        /*layout=*/nullptr);

    auto selectOp = arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                            xeGatherOp.getResult(),
                                            gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
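// Note: the pass-through semantics of vector.gather are recovered with the
// arith.select above, so masked-off lanes take their values from
// getPassThru() rather than from the scattered load itself.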
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets,
                                  scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
                                  /*layout=*/nullptr);
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    MemRefType memTy = loadOp.getBase().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(),
        getAsOpFoldResult(loadOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));
    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);
    return success();
  }
};
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    MemRefType memTy = storeOp.getBase().getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);
    return success();
  }
};
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    Value acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain, non-batched row-major matmul contractions.
    if (!isRowMajorMatmul(contractOp.getIndexingMaps()))
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects row-major matmul");

    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        contractOp.getResult().getType(),
                                        lhs, rhs, acc);
    rewriter.replaceOp(contractOp, dpasOp);
    return success();
  }
};
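// Example (illustrative): a canonical row-major matmul contraction such as
//   %r = vector.contract %lhs, %rhs, %acc
//     : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
// maps onto a single DPAS instruction, roughly `xegpu.dpas %lhs, %rhs, %acc`.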
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering,
           ContractionLowering>(patterns.getContext());
}
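// Usage sketch (assumed pass registration from Passes.td): the conversion
// can be exercised with `mlir-opt --convert-vector-to-xegpu input.mlir`, or
// applied programmatically:
//   RewritePatternSet patterns(ctx);
//   populateVectorToXeGPUConversionPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));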