29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
35#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
40#define DEBUG_TYPE "xegpu-optimize-peephole"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
/// Returns the effective lane data of `tdescType`'s layout as a vector of
/// integers. Only 2-D lane data passes the visible size check; presumably
/// std::nullopt is returned otherwise (the return statements are not part of
/// this extract — TODO confirm against the full file).
48static std::optional<SmallVector<int64_t>>
49getMaybeLaneData(xegpu::TensorDescType tdescType) {
50 auto layout = tdescType.getLayoutAttr();
// NOTE(review): original lines 51-52 are not visible here; they presumably
// bail out when the descriptor carries no layout before it is used below.
53 auto laneData = layout.getEffectiveLaneDataAsInt();
54 if (laneData.size() != 2)
/// Returns the effective lane layout of `tdescType`'s layout as a vector of
/// integers. Only a 2-D lane layout passes the visible size check; presumably
/// std::nullopt is returned otherwise (the return statements are not part of
/// this extract — TODO confirm against the full file).
60static std::optional<SmallVector<int64_t>>
61getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
62 auto layout = tdescType.getLayoutAttr();
// NOTE(review): original lines 63-64 are not visible here; they presumably
// bail out when the descriptor carries no layout before it is used below.
65 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
66 if (laneLayout.size() != 2)
// Core lane-layout/lane-data predicate. Its signature and return statements
// are on lines not visible in this extract; presumably each failed check
// below returns false and the function returns true otherwise.
// Both the lane layout and the lane data must be 2-D.
84 if (laneLayout.size() != 2 || laneData.size() != 2)
// The lane layout must be [N, 1] with N != 1.
86 if (laneLayout[0] == 1 || laneLayout[1] != 1)
// The lane data must be [1, K] with K != 1, i.e. each lane owns more than one
// contiguous element along the innermost dim.
88 if (laneData[0] != 1 || laneData[1] == 1)
/// Tensor-descriptor overload: decides whether `tdescType` is a candidate for
/// the packed-element load rewrite by checking its element width and then
/// delegating to the lane-layout/lane-data overload.
95static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
97 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
// Elements of 32 bits or wider are presumably rejected here (line 99 is not
// visible); packing several of them into one element is not attempted.
98 if (elementTyBitwidth >= 32)
100 auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
101 auto maybeLaneData = getMaybeLaneData(tdescType);
// Both the lane layout and the lane data must be available as 2-D vectors.
102 if (!maybeLaneData || !maybeLaneLayout)
104 return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
/// Tries to re-express `tdescType` as a descriptor with packed integer
/// elements. `innerLaneData` original elements are combined into one element
/// of `elementTyBitwidth * innerLaneData` bits. The block shape is limited to
/// what the target's Subgroup2DBlockLoad instruction reports as supported.
/// Boundary check and memory space are carried over unchanged.
/// The caller compares the result against `tdescType`, so an unchanged type
/// presumably signals "not optimizable" (the early-return lines are not
/// visible in this extract).
109static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
110 const uArch *targetuArch) {
111 if (!canBeOptimizedForTranspose(tdescType))
// Packing factor: the number of contiguous elements each lane owns.
113 auto laneData = getMaybeLaneData(tdescType)
115 int64_t innerLaneData = laneData[1];
116 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
// `requiredShape` is declared on lines not visible here (presumably a copy of
// the descriptor shape). All array slices are folded into the innermost dim,
// which is then shrunk by the packing factor.
121 requiredShape.back() =
122 requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
123 int newBitWidth = elementTyBitwidth * innerLaneData;
124 Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
// NOTE(review): the dyn_cast result is dereferenced on the next visible line
// with no null check in between. If the uArch exposes no
// Subgroup2DBlockLoad instruction, this would crash. Confirm.
127 auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
128 targetuArch->
getInstruction(InstructionKind::Subgroup2DBlockLoad));
129 auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
130 newElemTy,
false,
true);
// Lines 131-133 (not visible) presumably bail out when no hardware parameters
// are returned, before `.value()` is taken here.
134 auto [widths, heights, counts] = maybeHWParams.value();
// Presumably only a single block count is accepted (line 137 is not visible),
// so `arrayLen` below is always 1 when reached.
136 if (counts.size() != 1 || counts[0] != 1)
138 int arrayLen = counts[0];
// supportedHeight/supportedWidth are chosen on lines 139-143 (mostly not
// visible). -1 presumably means no supported size fits the required shape.
139 int supportedHeight =
144 if (supportedHeight == -1 || supportedWidth == -1)
148 auto ctx = tdescType.getContext();
// The new layout is rebuilt from the original lane layout and order. The
// remaining LayoutAttr::get arguments are on lines not visible here.
149 auto origLayout = tdescType.getLayoutAttr();
150 auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
152 laneLayoutI64.end());
154 xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
157 origLayout.getOrder());
160 return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
161 tdescType.getBoundaryCheck(),
162 tdescType.getMemorySpace(), newLayout);
/// Returns `ofr` as an SSA Value.
/// Lines 167-170 are not visible here. They presumably materialize a constant
/// (e.g. an index constant) when `ofr` holds an attribute — TODO confirm.
/// The final `llvm::cast<Value>` asserts in assertion-enabled builds if `ofr`
/// is still an attribute at that point.
166static Value convertToValue(ConversionPatternRewriter &rewriter,
Location loc,
171 return llvm::cast<Value>(ofr);
/// Emits IR computing `val / constant` with unsigned arithmetic:
///  - when `constant` is a power of two: an `arith.shrui` (presumably by
///    `shiftAmount` = log2(constant); its operands are on lines not visible);
///  - otherwise: an `arith.divui` by `constantOp`, created on lines not
///    visible here.
/// NOTE(review): both forms are unsigned, so this only divides correctly for
/// non-negative `val`. That is presumably guaranteed for the sizes, strides
/// and offsets it is applied to — confirm.
175static Value divideByConstant(ConversionPatternRewriter &rewriter,
Location loc,
178 if (llvm::isPowerOf2_64(constant)) {
179 int64_t shiftAmount = llvm::Log2_64(constant);
180 return arith::ShRUIOp::create(
188 return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
/// Fills `data` with a grid of smaller `xegpu.load_nd` ops. Each load produces
/// a `supportedShape` tile, which is inserted into `data` with
/// `vector.insert_strided_slice`; `data` is presumably returned at the end.
/// `data`, `shapeRatio`, `supportedShape`, `newTensorDesc` and `loc` come from
/// lines not visible in this extract. `shapeRatio` is presumably the ratio of
/// the data shape to `supportedShape` — TODO confirm.
/// The packed/transpose attributes, L1-L3 cache hints and layout are copied
/// from `origLoadOp`.
194static Value generateLoads(ConversionPatternRewriter &rewriter,
198 xegpu::LoadNdOp origLoadOp) {
200 assert(offsets.size() >= 2 &&
"Expecting at least 2 offsets for 2D LoadNdOp");
// Base offsets come from the two innermost dimensions.
201 Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
202 Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
// `h` walks tiles along dim 0, `w` along dim 1.
209 for (
int64_t h = 0; h < shapeRatio[0]; ++h) {
210 for (
int64_t w = 0; w < shapeRatio[1]; ++w) {
211 int64_t localOffsetDim0 = h * supportedShape[0];
212 int64_t localOffsetDim1 = w * supportedShape[1];
// Despite the X/Y names, `loadOffsetX` is the dim-0 offset and `loadOffsetY`
// is the dim-1 offset.
213 Value loadOffsetX = arith::AddIOp::create(
214 rewriter, loc, offsetDim0,
217 Value loadOffsetY = arith::AddIOp::create(
218 rewriter, loc, offsetDim1,
221 auto loadOp = xegpu::LoadNdOp::create(
223 VectorType::get(supportedShape, data.getType().getElementType()),
225 origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
226 origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
227 origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
// The anchor layout is taken from the packed tensor descriptor's type.
229 auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
230 loadOp.setAnchorLayout(layoutAttr);
232 auto insertOp = vector::InsertStridedSliceOp::create(
233 rewriter, loc, loadOp.getResult(), data,
238 data = insertOp.getResult();
/// Rewrites `xegpu.create_nd_tdesc` to use the packed descriptor type
/// computed by `tryOptimize`:
///  - the innermost size is divided by the packing factor (`innerLaneData`);
///  - the second-to-last stride is divided by the packing factor (the divisor
///    arguments are on lines not visible — presumably `innerLaneData`);
///  - a statically shaped memref source is replaced by its aligned base
///    pointer, index-cast to i64. This is presumably because the packed
///    shape/element type no longer matches the memref type — confirm.
/// The source's innermost stride must be the constant 1 (row-major).
248class XeGPUCreateNdDescOpPattern final
249 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
251 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
// The chip check below is an assertion.
// NOTE(review): release builds do not reject chips other than pvc/bmg/cri
// here.
255 auto tdescTy = createNdOp.getType();
260 (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg" ||
261 chipStr.value() ==
"cri") &&
262 "Expecting target chip to be pvc, bmg or cri for transpose "
// An unchanged type presumably means the pattern does not apply (line 268 is
// not visible).
266 auto convertType = tryOptimize(tdescTy, targetuArch);
267 if (convertType == tdescTy)
269 auto strides = createNdOp.getMixedStrides();
// The innermost stride must be the constant 1.
272 if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
273 return rewriter.notifyMatchFailure(
274 createNdOp,
"Expecting row-major memref for transpose optimization.");
275 Value source = createNdOp.getSource();
276 auto optionalLaneData = getMaybeLaneData(tdescTy);
277 assert(optionalLaneData &&
"Expected 2D lane data");
278 auto laneData = optionalLaneData.value();
279 int64_t innerLaneData = laneData[1];
280 auto memrefType = dyn_cast<MemRefType>(source.
getType());
// Express the innermost size in packed elements.
283 modifiedShape.back() = divideByConstant(
284 rewriter, createNdOp.getLoc(),
285 convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
288 assert(strides.size() >= 2 &&
289 "Expected at least 2 strides for CreateNdDescOp");
// Likewise express the second-to-last stride in packed elements.
291 modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
292 rewriter, createNdOp.getLoc(),
293 convertToValue(rewriter, createNdOp.getLoc(),
294 modifiedStrides[modifiedStrides.size() - 2]),
299 if (memrefType && memrefType.hasStaticShape()) {
300 auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
301 rewriter, createNdOp.getLoc(), source);
302 source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
303 rewriter.getI64Type(),
304 extractOp.getResult())
308 auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
309 rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
311 rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
/// Rewrites `xegpu.load_nd` whose descriptor type was changed by the
/// create-desc rewrite (the adaptor type differs from the original).
/// The innermost offset is divided by the packing factor, the data is loaded
/// as packed-integer tiles via `generateLoads`, and each result is bitcast back
/// to the original shape/element type.
/// With array length > 1, one value per array slice is produced and the op is
/// replaced 1:N via `replaceOpWithMultiple`.
320class XeGPULoadNdDescOpPattern final
321 :
public OpConversionPattern<xegpu::LoadNdOp> {
323 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
325 matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter)
const override {
327 auto origTensorDescType = loadNdOp.getTensorDescType();
329 cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
// An unchanged descriptor type presumably means there is nothing to rewrite
// (lines 331-332 are not visible).
330 if (adaptorType == origTensorDescType)
// NOTE(review): `.value()` is unchecked. It relies on the original descriptor
// having 2-D lane data, which the create-desc rewrite asserts before changing
// the type.
333 auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
334 int64_t innerLaneData = laneData[1];
335 auto offsets = loadNdOp.getMixedOffsets();
337 return rewriter.notifyMatchFailure(loadNdOp,
338 "Expecting offsets in LoadNd");
// Convert the innermost offset to packed-element units.
340 modifiedOffsets.back() = divideByConstant(
341 rewriter, loadNdOp.getLoc(),
342 convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
348 origDataShape.back() /= innerLaneData;
// `origDataShape` (declared on lines not visible, presumably from the original
// per-slice shape) now has its innermost dim in packed-element units.
351 VectorType origVectorType =
352 VectorType::get(origDataShape, adaptorType.getElementType());
// Array length > 1: load and bitcast each array slice separately.
355 if (origTensorDescType.getArrayLength() > 1) {
357 for (
int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
358 Value slice = arith::ConstantOp::create(
359 rewriter, loadNdOp->getLoc(), origVectorType,
360 rewriter.getZeroAttr(origVectorType));
// NOTE(review): possible offset bug for array length >= 3.
// Each iteration re-reads `modifiedOffsets.back()`, which the previous
// iteration already advanced, and then adds `i * origDataShape[1]`.
// With w = origDataShape[1], slice i therefore starts at
// base + w*(0+1+...+i) = base + w*i*(i+1)/2 rather than the presumably
// intended base + w*i. The two agree only for i <= 1.
// Unless a line not visible here resets the offset, fix by either:
//  - adding `origDataShape[1]` per iteration, or
//  - adding `i * origDataShape[1]` to a saved base offset.
// Confirm against the full file.
362 Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
363 modifiedOffsets.back());
364 modifiedOffsets.back() =
365 arith::AddIOp::create(
366 rewriter, loadNdOp->getLoc(), offsetY,
368 i * origDataShape[1])
371 slice = generateLoads(
376 auto bitcastType = VectorType::get(origTensorDescType.getShape(),
377 origTensorDescType.getElementType());
378 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
382 origTensorDescType.getLayoutAttr());
383 arraySlices.push_back(bitCastOp.getResult());
385 rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
// Single-slice case: one zero-initialized packed vector is filled by
// `generateLoads`, then bitcast back to the load's original result type.
388 data = arith::ConstantOp::create(
389 rewriter, loadNdOp->getLoc(),
390 VectorType::get(origDataShape, adaptorType.getElementType()),
391 rewriter.getZeroAttr(origVectorType));
392 data = generateLoads(
396 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
397 loadNdOp.getType(), data);
400 origTensorDescType.getLayoutAttr());
401 rewriter.replaceOp(loadNdOp, bitCastOp);
/// Resolves `vector.extract` whose source was split into several values by a
/// 1:N replacement (e.g. the per-slice results produced through
/// `replaceOpWithMultiple`). For a single position, the extract is replaced by
/// the matching piece of the converted source.
410class VectorExtractOpPattern final
411 :
public OpConversionPattern<vector::ExtractOp> {
413 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
416 ConversionPatternRewriter &rewriter)
const override {
// A source that was not split presumably needs no rewrite (line 419 is not
// visible).
418 if (adaptor.getSource().size() == 1)
420 auto mixedPos = extractOp.getMixedPosition();
// Only a single position is handled. It is presumably required to be a
// constant (`mayBeInt`, computed on lines 422-425, which are not visible).
421 if (mixedPos.size() != 1)
// NOTE(review): there is no visible bounds check of `*mayBeInt` against
// `adaptor.getSource().size()`.
426 rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
/// Splits a 2-D `vector.multi_reduction` into two 1-D reductions:
///  1. over the intra-lane dimension, seeded with an accumulator of the
///     reduced shape (presumably the reduction's neutral value);
///  2. over the cross-lane dimension, using the original accumulator.
/// When the result carries a layout, a `xegpu.convert_layout` bridges the
/// layout produced by the decomposition (`postDecompLayout`) back to the
/// original result layout.
433class MultiRed2dOpPattern
434 :
public OpConversionPattern<vector::MultiDimReductionOp> {
435 using OpConversionPattern::OpConversionPattern;
437 matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
438 ConversionPatternRewriter &rewriter)
const override {
439 auto sourceVecType = reductionOp.getSourceVectorType();
440 if (reductionOp.getReductionDims().size() != 2)
441 return rewriter.notifyMatchFailure(reductionOp,
"Expected 2D reduction");
444 auto dims = llvm::to_vector(reductionOp.getReductionDims());
// `resLayout` is obtained on lines 442-443 (not visible). The two reduced dims
// are classified by the lane layout.
445 auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
// If the layout does not classify both dims, fall back to the order in which
// the dims are listed.
447 if (intraLaneDim == -1 || crossLaneDim == -1) {
448 intraLaneDim = dims[0];
449 crossLaneDim = dims[1];
451 auto loc = reductionOp.getLoc();
452 auto acc = reductionOp.getAcc();
// `srcLayoutForCvt` is the parent of the result's slice layout, with a
// fallback on lines 471-472 (not visible) when there is none. From it, the
// layout of the decomposed result is derived.
464 xegpu::DistributeLayoutAttr postDecompLayout;
467 xegpu::DistributeLayoutAttr srcLayoutForCvt;
468 if (
auto resSlice = dyn_cast_if_present<xegpu::SliceAttr>(resLayout))
469 srcLayoutForCvt = resSlice.getParent();
470 if (!srcLayoutForCvt)
// The intra-lane dim is reduced first. A cross-lane dim that comes after it
// therefore shifts down by one; see the ternary below and the matching check
// before the second reduction. `postDecompLayout` is built as a slice of a
// slice of `srcLayoutForCvt`.
473 if (srcLayoutForCvt) {
481 crossLaneDim > intraLaneDim ? crossLaneDim - 1 : crossLaneDim;
482 auto intermediateLayout = xegpu::SliceAttr::get(
484 postDecompLayout = xegpu::SliceAttr::get(
485 ctx, intermediateLayout,
491 accShape.erase(accShape.begin() + intraLaneDim);
// `accShape` (declared on a line not visible) now excludes the intra-lane dim
// and types the accumulator of the first reduction.
492 Type eTy = sourceVecType.getElementType();
494 rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
// First reduction: over the intra-lane dim of the original source.
496 Value intraLaneReduced = vector::MultiDimReductionOp::create(
497 rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
// Re-index the cross-lane dim (body on line 502, not visible). Then reduce over
// it, seeding with the original accumulator.
501 if (crossLaneDim > intraLaneDim)
503 Value crossLaneReduced = vector::MultiDimReductionOp::create(
504 rewriter, loc, reductionOp.getKind(), intraLaneReduced,
acc,
506 assert(crossLaneReduced.
getType() == reductionOp.getResult().getType() &&
510 if (resLayout && postDecompLayout) {
516 auto bridgeOp = xegpu::ConvertLayoutOp::create(
517 rewriter, loc, crossLaneReduced.
getType(), crossLaneReduced,
518 postDecompLayout, resLayout);
/// Returns {intraLaneDim, crossLaneDim} for a 2-D reduction. The function name
/// is on line 528, not visible; it is called above as getReductionDimOrder.
/// A reduced dim whose lane layout is 1 is presumably the intra-lane dim, and
/// the other is the cross-lane dim. The branch bodies on lines 544-547 are not
/// visible. The caller treats -1 in either position as "could not classify".
527 std::pair<int64_t, int64_t>
529 xegpu::DistributeLayoutAttr layout)
const {
530 assert(layout.isForSubgroup() &&
"Must know the lane layout");
531 assert(reductionDims.size() == 2 &&
"Expected 2D reduction");
// Slice layouts are flattened to their underlying LayoutAttr.
533 xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
534 if (
auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
536 dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
// `laneLayout` is read from `layoutAttr` on lines 537-539 (not visible).
540 assert(laneLayout.size() &&
"Expected a non-empty layout");
542 for (
auto dim : reductionDims) {
543 if (laneLayout[dim] == 1)
548 return {intra, cross};
// Registers the four peephole conversion patterns. The enclosing function's
// signature is on lines not visible here (presumably
// populateXeGPUPeepHoleOptimizerPatterns).
556 patterns.
add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
557 VectorExtractOpPattern, MultiRed2dOpPattern>(
/// Pass driver, in three steps:
///  1. Bail out unless some GPU function targets pvc, bmg or cri.
///  2. Run the array-length optimization patterns greedily (the call is on
///     lines not visible).
///  3. Apply the peephole patterns with a partial conversion. Its legality
///     callbacks mark an op legal when it is not a transpose-optimization
///     candidate; reductions are legal unless they are 2-D with a subgroup
///     layout.
563struct XeGPUPeepHoleOptimizerPass final
564 :
public xegpu::impl::XeGPUPeepHoleOptimizerBase<
565 XeGPUPeepHoleOptimizerPass> {
566 void runOnOperation()
override {
574 bool isTargetSupported =
false;
// NOTE(review): the flag becomes true if *any* gpu.func targets pvc/bmg/cri.
// XeGPUCreateNdDescOpPattern, however, asserts that the op's own chip is one
// of those. A module mixing supported and unsupported targets looks able to
// reach that assert — confirm.
575 getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
577 if (chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg" ||
578 chipStr.value() ==
"cri"))
579 isTargetSupported =
true;
// NOTE(review): the debug message lists only PVC and BMG, although "cri" is
// also accepted by the check above.
582 if (!isTargetSupported) {
583 DBGS() <<
"XeGPUPeepHoleOptimizerPass only supports PVC, BMG targets."
591 RewritePatternSet arrayLenPatterns(&context);
594 std::move(arrayLenPatterns)))) {
595 DBGS() <<
"Array length optimization patterns failed.\n";
596 return signalPassFailure();
// Legality for the partial conversion: descriptors and loads are legal unless
// they are transpose-optimization candidates.
602 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
603 [&](xegpu::CreateNdDescOp createNdOp) {
604 return !canBeOptimizedForTranspose(createNdOp.getType());
606 target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
607 [&](xegpu::LoadNdOp loadNdOp) {
608 return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
// vector.extract is legal unless its result layout is a transpose-optimization
// candidate. The layout lookup is partly on lines not visible.
613 target.addDynamicallyLegalOp<vector::ExtractOp>(
614 [&](vector::ExtractOp extractOp) {
616 dyn_cast<OpResult>(extractOp.getResult()));
619 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
620 auto laneData = layout.getEffectiveLaneDataAsInt();
621 return !canBeOptimizedForTranspose(laneLayout, laneData);
// Reductions without a subgroup layout are presumably legal (line 628 is not
// visible). Otherwise only 2-D reductions are converted.
624 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
625 [=](Operation *op) ->
bool {
627 if (!layout || !layout.isForSubgroup())
629 if (
auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
630 return reductionOp.getReductionDims().size() != 2;
// Identity type conversion, presumably consumed by the structural conversions
// set up on lines not visible here.
634 converter.addConversion([](Type type) {
return type; });
636 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
637 vector::VectorDialect>();
640 target.addLegalOp<xegpu::ConvertLayoutOp>();
// Run the peephole conversion.
644 if (
failed(applyPartialConversion(getOperation(),
target,
645 std::move(patterns)))) {
646 DBGS() <<
"Optimize block loads pass failed.\n";
647 return signalPassFailure();
// A further step follows beyond the end of this extract. `emptyPatterns` is
// presumably used for a pattern-free greedy rewrite (folding only) — confirm.
652 RewritePatternSet emptyPatterns(ctx);
…if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. […] The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns)
Appends patterns for optimizing block load operations into patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
const Instruction * getInstruction(InstructionKind instKind) const