#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
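// Returns true if the given value is a constant integer or float zero.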
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
      .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
      .Default([](Attribute) { return false; });
}
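// Common preconditions for the block load/store lowerings: only 1D or 2D
// vectors with int or float elements are supported.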
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expected 1D or 2D vector");
  if (!vecTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Expected scalar type with known bitwidth");
  return success();
}
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");
  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");
  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults()) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  // The direct form is used only when the shape, offset, and all strides are
  // static.
  if (!srcTy.hasStaticShape())
    isStatic = false;
  if (!ShapedType::isStatic(offset))
    isStatic = false;
  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }
  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // For dynamic memrefs, pass the adjusted base address together with the
    // runtime sizes and strides taken from the strided metadata.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
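// Reorders the stride values to follow the dimension order imposed by the
// transfer op's permutation map.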
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}
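// Computes the strides (as index values) and the base offset of the op's
// source memref. Statically known values are materialized as constants;
// otherwise they are taken from memref::ExtractStridedMetadataOp. For
// transfer ops the strides are additionally adjusted for the permutation map.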
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t stride : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }
  if (strides.empty() || !offsetVal) {
    // Fall back to the dynamic path: extract the full strided metadata and
    // take the missing strides/offset from it.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();
    SmallVector<Type> resultTypes;
    // Base buffer.
    resultTypes.push_back(MemRefType::get({}, memrefType.getElementType()));
    // Offset.
    resultTypes.push_back(indexType);
    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType);
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType);

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());
    if (!offsetVal)
      offsetVal = meta.getOffset();
  }
  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
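// Computes flattened per-element offsets for a vector transfer op: step
// vectors for each vector dimension are scaled by the corresponding memref
// strides, broadcast to the full vector shape, summed, and combined with the
// scalar base offset derived from the transfer indices.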
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  // Create a step (iota) vector for every dimension of the vector type.
  SmallVector<Value> stepVectors;
  for (int64_t dim : vectorType.getShape()) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
  }
  // Multiply each step vector by the stride of its matching memref dimension.
  size_t vectorRank = vectorType.getRank();
  size_t memrefRank = strides.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }
  // Reshape each scaled step vector so that it spans only its own dimension.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorType.getShape()[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }
  // Broadcast each dimension's contribution to the full vector shape.
  auto fullIndexVectorType =
      VectorType::get(vectorType.getShape(), rewriter.getIndexType());
  SmallVector<Value> broadcasted;
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Fold the scalar transfer indices into the base offset.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  // Add the broadcast base offset to the per-element offsets.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}
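// Computes flattened per-lane offsets for vector.gather/vector.scatter: the
// scalar offsets are folded into the base offset and the per-lane index
// vector is scaled by the innermost stride.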
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  ValueRange offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }

  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());
  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
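// Collapses the outermost dimensions of a high-rank memref with a
// rank-reducing subview so that only the innermost `targetRank` dimensions
// (and their offsets) remain.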
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;
  SmallVector<int64_t> resultShape;

  // Collapsed outer dimensions take the original offsets with unit size/stride.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getIndexAttr(1));
    subviewStrides.push_back(rewriter.getIndexAttr(1));
  }

  // Remaining inner dimensions are kept whole; their offsets are applied later.
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getIndexAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getIndexAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getIndexAttr(1));
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}
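// Returns the aligned base pointer of the op's source memref as an i64 value.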
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              std::decay_t<OpType>, vector::TransferReadOp,
              vector::TransferWriteOp, vector::GatherOp,
              vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
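// Lowers vector.transfer_read to a scattered xegpu::LoadGatherOp on the
// flattened memref, using explicit per-element offsets and an all-true mask.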
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  // The transfer is unmasked, so use an all-true mask.
  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
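// Lowers vector.transfer_write to a scattered xegpu::StoreScatterOp on the
// flattened memref, using explicit per-element offsets and an all-true mask.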
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);
  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc,
      VectorType::get(vectorType.getShape(), rewriter.getI1Type()),
      vectorType.getShape());
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}
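// Lowers vector.transfer_read either through the scattered path (targets
// without 2D block loads, and 1D in-bounds reads) or to
// xegpu::CreateNdDescOp + xegpu::LoadNdOp, optionally with transposition.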
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // Targets without 2D block loads take the scattered path.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }
    // 1D in-bounds reads also go through the scattered path.
    VectorType vecTy = readOp.getVectorType();
    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
      return lowerToScatteredLoadOp(readOp, rewriter);

    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // The descriptor shape is reversed for transposed loads.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
    DenseI64ArrayAttr transposeAttr;
    if (isTransposeLoad)
      transposeAttr = DenseI64ArrayAttr::get(rewriter.getContext(),
                                             ArrayRef<int64_t>{1, 0});
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(),
        getAsOpFoldResult(readOp.getIndices()), vecTy.getRank());

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint, /*l2_hint=*/hint,
                                          /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp.getResult());
    return success();
  }
};
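// Lowers vector.transfer_write to xegpu::CreateNdDescOp + xegpu::StoreNdOp,
// falling back to the scattered store path on targets without 2D block
// stores.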
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // Targets without 2D block stores take the scattered path.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    // Only non-transposing writes are supported.
    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);

    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint, /*l2_hint=*/hint,
                                            /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);
    return success();
  }
};
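// Lowers vector.gather to a scattered xegpu::LoadGatherOp; masked-off lanes
// take their values from the gather's pass-through operand via arith.select.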
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
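// Lowers vector.scatter to a scattered xegpu::StoreScatterOp on the flattened
// memref.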
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
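// Lowers vector.load to xegpu::CreateNdDescOp + xegpu::LoadNdOp.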
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(),
        getAsOpFoldResult(loadOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint, /*l2_hint=*/hint,
                                /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp.getResult());
    return success();
  }
};
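// Lowers vector.store to xegpu::CreateNdDescOp + xegpu::StoreNdOp.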
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, cast<TypedValue<MemRefType>>(src));

    xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                             /*l1_hint=*/hint, /*l2_hint=*/hint,
                             /*l3_hint=*/hint);
    rewriter.eraseOp(storeOp);
    return success();
  }
};
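// Lowers a 2D additive vector.contract (row-major matmul) to xegpu::DpasOp.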
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    Value acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects 2D accumulator");

    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMaps()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc, TypeRange{accType},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp.getResult());
    return success();
  }
};
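// Pass driver: greedily applies the vector-to-XeGPU conversion patterns.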
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}