#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
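
// Returns true if the value is an arith constant equal to zero, e.g. the
// padding of a vector.transfer_read defined by `arith.constant 0.0 : f32`.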
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
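
// Common precondition shared by the load/store and transfer lowerings:
// XeGPU block (nd) operations only support 1D and 2D vectors.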
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}
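
// Extra checks for vector.transfer_read/write: the transfer must be unmasked,
// access a memref that is contiguous in its innermost dimension, use boundary
// checks only with 2D (block) shapes, and touch only the innermost source
// dimensions through its permutation map.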
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation())
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();

  // The compact form is only valid when shape, offset, and strides are all
  // statically known.
  bool isStatic = true;
  if (!srcTy.hasStaticShape())
    isStatic = false;
  if (!ShapedType::isStatic(offset))
    isStatic = false;
  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // For dynamic memrefs, pass the raw base address (adjusted by the element
    // offset in bytes) together with the runtime sizes and strides.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
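
// Reorders the collected strides to follow the transfer op's permutation map,
// so that a transposed access multiplies indices with the permuted strides.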
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}
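
// Collects the stride values (as index SSA values) and the scalar base offset
// of the op's source memref. Static strides and offsets are materialized as
// constants; dynamic ones are recovered via memref.extract_strided_metadata.
// Returns an empty stride list on failure.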
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  SmallVector<Value> strides;
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });
    // Static strides and offset can be materialized directly as constants.
    if (!hasDynamicStrides)
      for (int64_t stride : intStrides)
        strides.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, stride));
    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // Fall back to extracting the strided metadata at runtime.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base buffer
    resultTypes.push_back(indexType);      // offset
    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType);
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType);

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());
    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    // Strides must follow the transfer op's permutation map.
    AffineMap permMap = xferOp.getPermutationMap();
    adjustStridesForPermutation(permMap, strides);
  }
  return {strides, offsetVal};
}
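
// Builds a vector of linearized per-lane offsets for a transfer op. For a
// vector<MxN> access of a memref with strides [s0, s1] at indices [i, j],
// lane (r, c) roughly receives:
//   baseOffset + i * s0 + j * s1 + r * s0 + c * s1
// expressed with vector.step, vector.broadcast, and elementwise arith ops.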
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create a step vector (0, 1, ..., dim - 1) for every vector dimension.
  SmallVector<Value> stepVectors;
  for (int64_t dim : vectorShape) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
  }

  // Multiply each step vector by the stride of the matching memref dimension.
  size_t vectorRank = vectorShape.size();
  size_t memrefRank = strides.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape-cast the scaled step vectors so they broadcast against each other.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast all casted vectors to the full result shape and sum them up.
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  SmallVector<Value> broadcasted;
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Fold the scalar transfer indices into the base offset.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}
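
// Gather/scatter variant: the op's leading scalar offsets are folded into the
// base offset, and the per-lane index vector is scaled by the innermost
// stride before being added on top.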
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();

  auto offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
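
// Collapses the leading (rank - targetRank) dimensions of `memref` with a
// rank-reducing memref.subview taken at the corresponding offsets, so that
// descriptor creation only sees a 1D or 2D view. Returns the new source
// together with the offsets left over for the remaining dimensions.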
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // The leading dimensions are reduced away at the given offsets.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // The trailing dimensions are kept as-is.
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  // Only the offsets for the remaining (non-collapsed) dims are returned.
  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}
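
// Flattens the op's memref base into a raw i64 address using
// memref.extract_aligned_pointer_as_index followed by arith.index_cast.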
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
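
// Lowers vector.transfer_read to a scattered xegpu.load: per-lane offsets are
// computed explicitly and the access uses an all-true mask.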
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
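
// Mirror of the read case for vector.transfer_write: the value is written
// through a scattered xegpu.store and the original op is erased.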
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}
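
// Pattern: vector.transfer_read. On pvc/bmg, 2D reads use the block (nd)
// path, schematically:
//   %desc = xegpu.create_nd_tdesc %src ... -> !xegpu.tensor_desc<...>
//   %vec  = xegpu.load_nd %desc[%i, %j] ... -> vector<...>
// Other chips and 1D in-bounds reads fall back to the scattered lowering
// above.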
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Boundary checks are only available on the block path.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();
    // 1D in-bounds reads map better onto the scattered form.
    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
      return lowerToScatteredLoadOp(readOp, rewriter);

    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // A transposed load uses the pre-transposed descriptor shape.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(),
        getAsOpFoldResult(readOp.getIndices()), vecTy.getRank());

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint,
                                          /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
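
// Pattern: vector.transfer_write. Same chip-based dispatch as the read case;
// the block path stores through xegpu.create_nd_tdesc + xegpu.store_nd.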
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Boundary checks are only available on the block path.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint,
                                            /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
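
// Pattern: vector.gather -> scattered xegpu.load plus an arith.select that
// substitutes the pass-through value on masked-off lanes.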
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
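
// Pattern: vector.scatter -> scattered xegpu.store; masked-off lanes are
// suppressed by the store's mask directly.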
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
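
// Pattern: vector.load -> xegpu.create_nd_tdesc + xegpu.load_nd (block load).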
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
        vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
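
// Pattern: vector.store -> xegpu.create_nd_tdesc + xegpu.store_nd (block
// store).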
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
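
// Pattern: 2D row-major, additive vector.contract -> xegpu.dpas.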
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Only plain row-major matmul-shaped contractions map onto DPAS.
    if (!isRowMajorMatmul(contractOp.getIndexingMaps()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
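
// Pass driver: greedily applies all vector-to-XeGPU lowering patterns to the
// enclosing operation.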
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}