26#include "llvm/ADT/TypeSwitch.h"
32#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
33#include "mlir/Conversion/Passes.h.inc"
41static bool isZeroConstant(
Value val) {
47 .Case([](FloatAttr floatAttr) {
return floatAttr.getValue().isZero(); })
48 .Case([](IntegerAttr intAttr) {
return intAttr.getValue().isZero(); })
57 unsigned vecRank = vecTy.getRank();
58 if (!(vecRank == 1 || vecRank == 2))
61 if (!vecTy.getElementType().isIntOrFloat())
63 op,
"Expected scalar type with known bitwidth");
69 if (!memTy.getElementType().isIntOrFloat())
71 op,
"Unsupported memref element type: expected integer or float");
77 VectorTransferOpInterface xferOp) {
80 "Masked transfer is not supported");
82 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
89 if (
failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
91 xferOp,
"Buffer must be contiguous in the innermost dimension");
93 VectorType vecTy = xferOp.getVectorType();
94 unsigned vecRank = vecTy.getRank();
95 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
97 xferOp,
"Boundary check is available only for block instructions.");
104 auto dim = dyn_cast<AffineDimExpr>(expr);
105 if (dim.getPosition() < (numInputDims - vecRank))
107 xferOp,
"Only the innermost dimensions can be accessed");
113static xegpu::CreateNdDescOp createNdDescriptor(
PatternRewriter &rewriter,
115 xegpu::TensorDescType descType,
117 MemRefType srcTy = src.getType();
118 assert(srcTy.isStrided() &&
"Expected strided memref type");
119 auto [strides, offset] = srcTy.getStridesAndOffset();
120 bool isStatic =
true;
123 if (!srcTy.hasStaticShape())
126 if (!ShapedType::isStatic(offset))
129 for (
auto stride : strides) {
130 if (!ShapedType::isStatic(stride)) {
136 xegpu::CreateNdDescOp ndDesc;
138 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
143 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
144 auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
145 rewriter, loc, meta.getBaseBuffer());
146 auto offset = meta.getOffset();
147 auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
148 auto offsetInBytes = arith::MulIOp::create(
149 rewriter, loc, offset,
151 auto adjustedBaseAddr = arith::AddIOp::create(
152 rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
153 auto adjustedAddrI64 = arith::IndexCastOp::create(
154 rewriter, loc, rewriter.
getI64Type(), adjustedBaseAddr);
155 ndDesc = xegpu::CreateNdDescOp::create(
156 rewriter, loc, descType, adjustedAddrI64,
157 meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
180static void adjustStridesForPermutation(
AffineMap permMap,
195 typename = std::enable_if_t<llvm::is_one_of<
196 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
197 vector::GatherOp, vector::ScatterOp>::value>>
198static std::pair<SmallVector<Value>,
Value>
201 Value baseMemref = xferOp.getBase();
202 MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.
getType());
205 Value offsetVal =
nullptr;
206 if (memrefType.hasStaticShape()) {
209 if (
failed(memrefType.getStridesAndOffset(intStrides, offset)))
210 return {{}, offsetVal};
211 bool hasDynamicStrides = llvm::any_of(intStrides, [](
int64_t strideVal) {
212 return ShapedType::isDynamic(strideVal);
215 if (!hasDynamicStrides)
219 if (!ShapedType::isDynamic(offset))
223 if (strides.empty() || !offsetVal) {
226 unsigned rank = memrefType.getRank();
232 resultTypes.push_back(MemRefType::get(
233 {}, memrefType.getElementType()));
234 resultTypes.push_back(indexType);
236 for (
unsigned i = 0; i < rank; ++i)
237 resultTypes.push_back(indexType);
239 for (
unsigned i = 0; i < rank; ++i)
240 resultTypes.push_back(indexType);
242 auto meta = memref::ExtractStridedMetadataOp::create(
243 rewriter, loc, resultTypes, baseMemref);
246 strides.append(meta.getStrides().begin(), meta.getStrides().end());
249 offsetVal = meta.getOffset();
252 if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
253 vector::TransferWriteOp>::value) {
256 adjustStridesForPermutation(permMap, strides);
259 return {strides, offsetVal};
290static Value computeOffsets(VectorTransferOpInterface xferOp,
294 VectorType vectorType = xferOp.getVectorType();
296 xferOp.getIndices().end());
302 auto stepType = VectorType::get({dim}, rewriter.
getIndexType());
303 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
304 stepVectors.push_back(stepOp);
309 size_t memrefRank = strides.size();
312 for (
size_t i = 0; i < vectorRank; ++i) {
313 size_t memrefDim = memrefRank - vectorRank + i;
314 Value strideValue = strides[memrefDim];
315 auto mulType = dyn_cast<VectorType>(stepVectors[i].
getType());
317 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
318 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
319 strideMultiplied.push_back(mulOp);
324 for (
size_t i = 0; i < vectorRank; ++i) {
327 auto newType = VectorType::get(newShape, rewriter.
getIndexType());
328 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
329 strideMultiplied[i]);
330 shapeCasted.push_back(castOp);
335 auto fullIndexVectorType =
337 for (
Value shapeCastVal : shapeCasted) {
338 auto broadcastOp = vector::BroadcastOp::create(
339 rewriter, loc, fullIndexVectorType, shapeCastVal);
340 broadcasted.push_back(broadcastOp);
344 Value localOffsets = broadcasted[0];
345 for (
size_t i = 1; i < broadcasted.size(); ++i)
347 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
350 for (
size_t i = 0; i <
indices.size(); ++i) {
351 Value strideVal = strides[i];
352 Value offsetContrib =
353 arith::MulIOp::create(rewriter, loc,
indices[i], strideVal);
355 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
358 Value bcastBase = vector::BroadcastOp::create(
359 rewriter, loc, fullIndexVectorType, baseOffset);
360 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
371 typename = std::enable_if_t<llvm::is_one_of<
372 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
377 for (
size_t i = 0; i < offsets.size(); ++i) {
378 Value offsetContrib =
379 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
381 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
384 VectorType vecType = cast<VectorType>(
indices.getType());
387 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
389 Value stridedIndices =
390 arith::MulIOp::create(rewriter, loc, strideVector,
indices).getResult();
393 vector::BroadcastOp::create(
395 VectorType::get(vecType.getShape(), rewriter.
getIndexType()),
398 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
407static std::pair<Value, SmallVector<OpFoldResult>>
412 auto memrefType = cast<MemRefType>(
memref.getType());
413 unsigned rank = memrefType.getRank();
415 if (rank <= targetRank)
418 int64_t numCombinedDims = rank - targetRank;
424 for (
unsigned i = 0; i < numCombinedDims; ++i) {
425 subviewOffsets.push_back(offsets[i]);
432 auto originalShape = memrefType.getShape();
433 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc,
memref);
434 for (
unsigned i = numCombinedDims; i < rank; ++i) {
436 if (ShapedType::isDynamic(originalShape[i])) {
437 subviewSizes.push_back(meta.getSizes()[i]);
438 resultShape.push_back(ShapedType::kDynamic);
441 resultShape.push_back(originalShape[i]);
446 auto resultType = memref::SubViewOp::inferRankReducedResultType(
447 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
449 memref::SubViewOp::create(rewriter, loc, resultType,
memref,
450 subviewOffsets, subviewSizes, subviewStrides);
455 return {subviewOp.getResult(), newOffsets};
460 typename = std::enable_if_t<llvm::is_one_of<
461 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
462 vector::GatherOp, vector::ScatterOp>::value>>
466 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
467 rewriter, loc, xferOp.getBase())
469 return arith::IndexCastOp::create(rewriter, loc, rewriter.
getI64Type(),
474static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
478 VectorType vectorType = readOp.getVectorType();
480 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
484 auto meta = computeMemrefMeta(readOp, rewriter);
485 if (meta.first.empty())
489 computeOffsets(readOp, rewriter, meta.first, meta.second);
491 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
493 Value mask = vector::ConstantMaskOp::create(
496 auto gatherOp = xegpu::LoadGatherOp::create(
497 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
499 xegpu::CachePolicyAttr{},
500 xegpu::CachePolicyAttr{},
501 xegpu::CachePolicyAttr{},
504 rewriter.
replaceOp(readOp, gatherOp.getResult());
508static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
512 VectorType vectorType = writeOp.getVectorType();
515 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
519 auto meta = computeMemrefMeta(writeOp, rewriter);
520 if (meta.first.empty())
524 computeOffsets(writeOp, rewriter, meta.first, meta.second);
526 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
528 Value mask = vector::ConstantMaskOp::create(
531 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
534 xegpu::CachePolicyAttr{},
535 xegpu::CachePolicyAttr{},
536 xegpu::CachePolicyAttr{},
542struct TransferReadLowering :
public OpRewritePattern<vector::TransferReadOp> {
545 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
546 PatternRewriter &rewriter)
const override {
547 Location loc = readOp.getLoc();
549 if (
failed(transferPreconditions(rewriter, readOp)))
551 auto readMemTy = cast<MemRefType>(readOp.getShapedType());
552 VectorType loadedVecTy = readOp.getVectorType();
553 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
555 bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
559 if (loadedVecTy.getRank() != 2)
561 readOp,
"Only 2D vector loads are supported for SLM");
562 AffineMap readMap = readOp.getPermutationMap();
566 "Non identity transposition is not supported for SLM loads.");
570 readOp,
"Out-of-bounds access is not supported for SLM loads");
574 xegpu::MemDescType::get(rewriter.
getContext(), readMemTy.getShape(),
575 readMemTy.getElementType(),
577 auto createMemDescOp = xegpu::CreateMemDescOp::create(
578 rewriter, loc, memDescType, readOp.getBase());
580 SmallVector<OpFoldResult>
indices =
582 auto loadMatrixOp = xegpu::LoadMatrixOp::create(
583 rewriter, loc, loadedVecTy, createMemDescOp.getResult(),
indices,
586 rewriter.
replaceOp(readOp, loadMatrixOp.getResult());
594 if ((chip !=
"pvc" && chip !=
"bmg" && chip !=
"cri") ||
595 readOp.getVectorType().getRank() > 2) {
600 return lowerToScatteredLoadOp(readOp, rewriter);
604 if (loadedVecTy.getRank() == 1 && !isOutOfBounds)
605 return lowerToScatteredLoadOp(readOp, rewriter);
610 storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
613 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
615 readOp,
"Unsupported non-zero padded out-of-bounds read");
624 bool isTransposeLoad =
false;
628 if (numInputs >= 2) {
633 (results[0] == lastDim && results[1] == secondLastDim);
636 auto elementType = loadedVecTy.getElementType();
638 SmallVector<int64_t> descShape(loadedVecTy.getShape());
639 if (isTransposeLoad) {
642 size_t rank = descShape.size();
643 assert(rank >= 2 &&
"Transpose requires at least 2 dimensions");
644 std::swap(descShape[rank - 1], descShape[rank - 2]);
645 loadedVecTy = VectorType::get(descShape, elementType);
647 auto descType = xegpu::TensorDescType::get(
648 descShape, elementType, 1,
649 isOutOfBounds, xegpu::MemorySpace::Global);
650 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
652 loadedVecTy.getRank());
654 xegpu::CachePolicyAttr hint =
nullptr;
655 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
658 Operation *loadedOp =
659 xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc,
indices,
664 if (isTransposeLoad) {
667 auto range = llvm::seq<int64_t>(0, readMap.
getResults().size());
668 SmallVector<int64_t> perm(
669 range.rbegin(), range.rend());
670 loadedOp = vector::TransposeOp::create(rewriter, loc,
679struct TransferWriteLowering
683 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
684 PatternRewriter &rewriter)
const override {
685 Location loc = writeOp.getLoc();
687 if (
failed(transferPreconditions(rewriter, writeOp)))
690 VectorType vecTy = writeOp.getVectorType();
691 auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
693 bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);
699 if (vecTy.getRank() != 2)
701 writeOp,
"Only 2D vector stores are supported for SLM");
704 xegpu::MemDescType::get(rewriter.
getContext(), writeMemTy.getShape(),
705 writeMemTy.getElementType(),
708 auto createMemDescOp = xegpu::CreateMemDescOp::create(
709 rewriter, loc, memDescType, writeOp.getBase());
712 SmallVector<OpFoldResult>
indices =
715 xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
716 createMemDescOp.getResult(),
indices,
727 if ((chip !=
"pvc" && chip !=
"bmg" && chip !=
"cri") ||
728 writeOp.getVectorType().getRank() > 2) {
731 if (writeOp.hasOutOfBoundsDim())
733 return lowerToScatteredStoreOp(writeOp, rewriter);
736 if (
failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
743 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
744 rewriter, loc, writeOp.getBase(),
747 auto descType = xegpu::TensorDescType::get(
748 vecTy.getShape(), vecTy.getElementType(),
749 1, writeOp.hasOutOfBoundsDim(),
750 xegpu::MemorySpace::Global);
752 xegpu::CachePolicyAttr hint =
nullptr;
753 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
756 auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
770 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
771 PatternRewriter &rewriter)
const override {
772 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
776 Location loc = gatherOp.getLoc();
777 VectorType vectorType = gatherOp.getVectorType();
779 auto meta = computeMemrefMeta(gatherOp, rewriter);
780 if (meta.first.empty())
784 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
785 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
787 auto xeGatherOp = xegpu::LoadGatherOp::create(
788 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
790 xegpu::CachePolicyAttr{},
791 xegpu::CachePolicyAttr{},
792 xegpu::CachePolicyAttr{},
796 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
797 xeGatherOp.getResult(), gatherOp.getPassThru());
798 rewriter.
replaceOp(gatherOp, selectOp.getResult());
806 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
807 PatternRewriter &rewriter)
const override {
808 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
812 Location loc = scatterOp.getLoc();
813 auto meta = computeMemrefMeta(scatterOp, rewriter);
814 if (meta.first.empty())
816 "Failed to compute strides");
819 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
820 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
822 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
823 flatMemref, localOffsets, scatterOp.getMask(),
825 xegpu::CachePolicyAttr{},
826 xegpu::CachePolicyAttr{},
827 xegpu::CachePolicyAttr{},
837 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
838 PatternRewriter &rewriter)
const override {
839 Location loc = loadOp.getLoc();
841 VectorType vecTy = loadOp.getResult().getType();
842 MemRefType memTy = loadOp.getBase().getType();
843 if (
failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
847 bool boundaryCheck = vecTy.getRank() > 1;
849 xegpu::CachePolicyAttr hint =
nullptr;
851 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
855 auto descType = xegpu::TensorDescType::get(
856 vecTy.getShape(), vecTy.getElementType(), 1,
857 boundaryCheck, xegpu::MemorySpace::Global);
859 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
862 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
indices,
876 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
877 PatternRewriter &rewriter)
const override {
878 Location loc = storeOp.getLoc();
881 VectorType vecTy = vector.getType();
882 MemRefType memTy = storeOp.getBase().getType();
883 if (
failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
887 bool boundaryCheck = vecTy.getRank() > 1;
889 auto [src,
indices] = convertMemrefAndOffsetsToTargetRank(
890 rewriter, loc, storeOp.getBase(),
893 auto descType = xegpu::TensorDescType::get(
894 vecTy.getShape(), vecTy.getElementType(),
895 1, boundaryCheck, xegpu::MemorySpace::Global);
898 xegpu::CachePolicyAttr hint =
nullptr;
899 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
903 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
indices,
914struct ContractionLowering :
public OpRewritePattern<vector::ContractionOp> {
917 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
918 PatternRewriter &rewriter)
const override {
919 Location loc = contractOp.getLoc();
921 if (contractOp.getKind() != vector::CombiningKind::ADD)
923 "Expects add combining kind");
926 VectorType accType = dyn_cast<VectorType>(acc.getType());
927 if (!accType || accType.getRank() != 2)
934 if (
lhs.getType().getRank() != 2 ||
rhs.getType().getRank() != 2)
936 "Expects lhs and rhs 2D vectors");
941 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
951static MemRefType withMemorySpace(MemRefType memrefTy,
Attribute newMemSpace) {
952 return MemRefType::get(memrefTy.getShape(), memrefTy.getElementType(),
953 memrefTy.getLayout(), newMemSpace);
966static void promoteAllocasToSLM(
Operation *root) {
968 Attribute slmAttr = IntegerAttr::get(IntegerType::get(ctx, 64), 3);
974 auto isMemrefResultOp = [](
Operation *op) {
978 [](
Type t) { return isa<MemRefType>(t); });
984 auto memrefTy = dyn_cast<MemRefType>(v.getType());
985 if (!memrefTy || xegpu::XeGPUDialect::isSharedMemory(memrefTy))
987 v.setType(withMemorySpace(memrefTy, slmAttr));
989 if (!isMemrefResultOp(user))
997 root->
walk([&](memref::AllocaOp op) {
998 auto memrefTy = dyn_cast<MemRefType>(op.getResult().getType());
999 if (!memrefTy || xegpu::XeGPUDialect::isSharedMemory(memrefTy))
1001 allocas.push_back(op);
1004 for (memref::AllocaOp alloca : allocas) {
1006 auto memrefTy = cast<MemRefType>(alloca.getResult().getType());
1007 auto newTy = withMemorySpace(memrefTy, slmAttr);
1008 auto newOp = memref::AllocaOp::create(
1009 builder, alloca.getLoc(), newTy, alloca.getDynamicSizes(),
1010 alloca.getSymbolOperands(), alloca.getAlignmentAttr());
1011 alloca.getResult().replaceAllUsesWith(newOp.getResult());
1015 if (!isMemrefResultOp(user))
1023struct ConvertVectorToXeGPUPass
1024 :
public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
1025 void runOnOperation()
override {
1028 promoteAllocasToSLM(getOperation());
1034 return signalPassFailure();
1043 .
add<TransferReadLowering, TransferWriteLowering, LoadLowering,
1044 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
static std::optional< VectorShape > vectorShape(Type type)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_type_range getResultTypes()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...