#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

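// The generated pass base class (impl::ConvertVectorToXeGPUBase) comes from
// the include above; ConvertVectorToXeGPUPass at the bottom of this file
// derives from it.
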
namespace {

// Returns true if the value is defined by a constant zero (integer or
// floating point).
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

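// isZeroConstant is used by TransferReadLowering below to reject
// out-of-bounds vector.transfer_read ops whose padding value is not a
// compile-time zero constant.
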
// Common precondition for the load/store lowerings: XeGPU block ops handle
// only 1D and 2D vectors.
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  // Only the innermost memref dimensions may be accessed by the vector, e.g.
  // for a 2D vector over a 3D memref, (d0, d1, d2) -> (d1, d2) is accepted
  // while (d0, d1, d2) -> (d0, d2) is rejected.
  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // For dynamic shapes, the source's shape and strides have to be provided
    // explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order; the innermost stride is unit.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

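// Illustrative result for a statically shaped source (a sketch, not verbatim
// compiler output):
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// For dynamically shaped sources, offsets, shape, and strides are additionally
// passed as SSA values next to their static placeholders.
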
// Permutes the collected strides so they follow the dimension order given by
// the transfer's permutation map (e.g. for transposed accesses).
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Collects the memref strides as index values, handling both static and
// dynamic shapes, and permutes them according to the transfer's map.
static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
                                         PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  AffineMap permMap = xferOp.getPermutationMap();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {};
    // Wrap static strides as constant index values.
    for (int64_t s : intStrides)
      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
  } else {
    // For dynamic shapes, extract the strides at runtime.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, ..., size0, ...].
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get({}, memrefType.getElementType()));
    resultTypes.push_back(indexType);
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides
    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
    strides.append(meta.getStrides().begin(), meta.getStrides().end());
  }
  // Adjust strides according to the permutation map (e.g., for transpose).
  adjustStridesForPermutation(permMap, strides);
  return strides;
}

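// For example, a row-major memref<8x16xf32> yields the constant strides
// [16, 1]; a transposing permutation map then reorders them so that each entry
// lines up with the vector dimension it multiplies in computeOffsets below.
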
// Computes per-element offsets (in elements) covering the vector shape of
// the transfer, to feed a gather/scatter access.
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter,
                            ArrayRef<Value> strides) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create a vector.step op for each vector dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply each step vector by the stride of its memref dimension.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Sum the broadcasted vectors to get the per-element local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute the scalar base offset from the transfer indices and add it.
  Value baseOffset = nullptr;
  if (!indices.empty()) {
    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
    for (size_t i = 0; i < indices.size(); ++i) {
      Value strideVal = strides[i];
      Value offsetContrib =
          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
      baseOffset =
          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
    }
    Value bcastBase = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, baseOffset);
    localOffsets =
        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  }
  return localOffsets;
}

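// Worked example (sketch): for a vector<2x4xf32> access over strides [16, 1],
// the step vectors are [0, 1] and [0, 1, 2, 3]; after the stride multiply,
// shape casts, and broadcasts, the summed local offsets are
//   [[0, 1, 2, 3], [16, 17, 18, 19]],
// and the scalar base offset %i * 16 + %j is broadcast and added on top.
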
// Collapses the n-D source memref into a 1-D memref for scattered accesses.
static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
                                PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();

  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
  Type elementType = memrefType.getElementType();

  MemRefType flatMemrefType;
  if (memrefType.hasStaticShape()) {
    auto totalElements = memrefType.getNumElements();
    flatMemrefType = MemRefType::get({totalElements}, elementType);
  } else {
    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
  }

  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices allDims =
      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
  reassociation.push_back(allDims);

  auto collapseOp = memref::CollapseShapeOp::create(
      rewriter, loc, flatMemrefType, baseMemref, reassociation);
  return collapseOp;
}

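// E.g. memref<8x16xf32> collapses to memref<128xf32> through a single
// memref.collapse_shape with reassociation [[0, 1]]; dynamically shaped
// sources collapse to memref<?xf32>.
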
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  SmallVector<Value> strides = computeStrides(readOp, rewriter);
  Value localOffsets = computeOffsets(readOp, rewriter, strides);
  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}

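// The emitted sequence is: flatten the source to 1D, build an all-true
// vector<...xi1> mask matching the vector shape, and issue an
// xegpu::LoadGatherOp with the per-element offsets computed above; the
// original vector.transfer_read is replaced by the gather result.
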
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // Targets without 2D block load support take the scattered path.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, use the transposed shape for the descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

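// Illustrative lowering on the block (nd-descriptor) path, as a sketch with
// abbreviated types rather than verbatim output:
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<8x16xf32>, vector<8x16xf32>
// becomes roughly:
//   %d = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %d : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// Transposed reads additionally set the transpose attribute on xegpu.load_nd.
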
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // Targets without 2D block store support take the scattered path.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

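// The write path mirrors the read path: an nd descriptor is created for the
// destination and the stored vector is written with xegpu.store_nd; only
// minor-identity permutation maps are accepted here.
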
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc,
        /*packed=*/nullptr,
        /*transpose=*/nullptr,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

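// Unlike the transfer ops, vector.load carries no padding value, so the
// descriptor's boundary check is enabled only for rank-2 (block) accesses,
// matching the boundaryCheck flag above.
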
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                                 /*l1_hint=*/hint, /*l2_hint=*/hint,
                                 /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D x 2D matmul.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

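// In effect, a row-major 2D x 2D vector.contract with an additive accumulator
// maps onto a single xegpu.dpas taking (lhs, rhs, acc).
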
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}

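// Typical standalone usage (a sketch; ConvertVectorToXeGPUPass above does the
// equivalent):
//   RewritePatternSet patterns(ctx);
//   populateVectorToXeGPUConversionPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
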