24#include "llvm/ADT/TypeSwitch.h"
30#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
31#include "mlir/Conversion/Passes.h.inc"
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
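// Common rank precondition shared by the load/store lowerings below; the
// (rewriter, op, vecTy) parameter order matches its call sites in the
// patterns further down. Only 1-D and 2-D vectors are accepted.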
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // For dynamic shapes, the source's sizes and strides have to be provided
    // explicitly.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           meta.getConstifiedMixedSizes(),
                                           meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
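// Adjusts the memref strides according to the transfer op's permutation map
// (second parameter inferred from its use in computeMemrefMeta below).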
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides);
// Computes the memref strides and offset as SSA values for the given op,
// preferring compile-time constants when the layout is fully static.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  SmallVector<Value> strides;
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    SmallVector<int64_t> intStrides;
    int64_t offset;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t stride : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // Fall back to extracting the metadata at runtime.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get({}, memrefType.getElementType()));
    resultTypes.push_back(indexType); // offset
    for (unsigned i = 0; i < rank; ++i) // sizes
      resultTypes.push_back(indexType);
    for (unsigned i = 0; i < rank; ++i) // strides
      resultTypes.push_back(indexType);

    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc,
                                                         resultTypes,
                                                         baseMemref);
    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());
    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
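// Computes per-element linear offsets for a vector.transfer_read/write:
// a vector.step per result dimension is scaled by the matching memref stride,
// shape-cast to a unit-dim shape, broadcast to the full vector shape, and the
// per-dimension contributions are summed together with the broadcast base
// offset. E.g., for a row-major memref<8x16xf32> (strides [16, 1]) and a
// vector<2x4xf32> read at indices (i, j), the produced offsets are
// i*16 + j + [[0,1,2,3],[16,17,18,19]].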
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create a vector.step op for each result dimension.
  SmallVector<Value> stepVectors;
  for (int64_t dim : vectorShape) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
  }

  // Multiply each step vector by the stride of the innermost memref dimension
  // it addresses.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Reshape each partial offset to a shape that keeps only its own dimension.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast every partial offset to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Sum the per-dimension contributions into the local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Fold the transfer indices into the scalar base offset.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Add the broadcast base offset to the local offsets.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}
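// Gather/scatter variant: the scalar memref offsets are folded into a single
// base offset, and the per-lane index vector is scaled by the innermost
// stride before being added on top.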
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets(gatScatOp.getOffsets().begin(),
                             gatScatOp.getOffsets().end());
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
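// Collapses the leading memref dimensions into a rank-reduced subview so that
// at most `targetRank` dimensions remain; the leading offsets are absorbed by
// the subview and only the trailing offsets are returned to the caller.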
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // The leading dimensions are collapsed: their offsets go into the subview
  // with unit size and stride.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getIndexAttr(1));
    subviewStrides.push_back(rewriter.getIndexAttr(1));
  }

  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getIndexAttr(0));
    subviewStrides.push_back(rewriter.getIndexAttr(1));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getIndexAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}
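// Extracts the aligned base pointer of the memref and casts it to i64 so it
// can be used as the flat address operand of the scattered XeGPU ops.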
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
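// Fallback path used by the transfer lowerings when the target chip is not
// "pvc" or "bmg": transfer reads and writes are lowered to
// xegpu::LoadGatherOp / xegpu::StoreScatterOp over a flattened i64 base
// pointer, with explicitly computed per-element offsets and an all-true mask.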
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expects memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expects memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  xegpu::StoreScatterOp::create(
      rewriter, loc, writeOp.getVector(), flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{}, /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{}, /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}
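// Lowers vector.transfer_read to an nd block load (xegpu::LoadNdOp) on
// "pvc"/"bmg" targets, with an optional transpose for non-minor-identity
// permutation maps; other targets take the scattered load path above.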
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // Without 2D block instruction support, fall back to scattered accesses.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // The descriptor shape is reversed for transposed loads.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    DenseI64ArrayAttr transposeAttr =
        isTransposeLoad ? DenseI64ArrayAttr::get(rewriter.getContext(),
                                                 ArrayRef<int64_t>{1, 0})
                        : DenseI64ArrayAttr{};

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(),
        getAsOpFoldResult(readOp.getIndices()), vecTy.getRank());

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint, /*l2_hint=*/hint,
                                          /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
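// Lowers vector.transfer_write analogously: an nd block store
// (xegpu::StoreNdOp) on supported chips, the scattered store path otherwise.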
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint, /*l2_hint=*/hint,
                                            /*l3_hint=*/hint);
    rewriter.eraseOp(writeOp);

    return success();
  }
};
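// The next two patterns lower vector.gather and vector.scatter to
// xegpu::LoadGatherOp / xegpu::StoreScatterOp over a flattened i64 base
// pointer; the gather result is blended with the pass-through value via
// arith.select to honor masked-off lanes.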
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(
        rewriter, loc, scatterOp.getValueToStore(), flatMemref, localOffsets,
        scatterOp.getMask(), /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
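// Plain vector.load / vector.store are mapped onto nd block loads and stores;
// boundary checking is enabled only for rank-2 (block) accesses.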
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;
    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
        vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint, /*l2_hint=*/hint,
                                /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint, /*l2_hint=*/hint,
                                 /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
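// Lowers a 2-D additive vector.contract (matmul-like) to xegpu::DpasOp.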
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    Value acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
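// Pass driver: greedily applies the conversion patterns registered in
// populateVectorToXeGPUConversionPatterns.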
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

void mlir::populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}