29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
35#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
40#define DEBUG_TYPE "xegpu-optimize-peephole"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Returns the effective lane-data vector of `tdescType`'s layout attribute
// when it is 2-D; otherwise (per the size check below) the function
// presumably returns std::nullopt — the return lines are missing from this
// extraction (original lines 51-52, 55-59 absent). NOTE(review): stray
// numeric prefixes on each line are extraction artifacts, not code.
48static std::optional<SmallVector<int64_t>>
49getMaybeLaneData(xegpu::TensorDescType tdescType) {
50 auto layout = tdescType.getLayoutAttr();
53 auto laneData = layout.getEffectiveLaneDataAsInt();
54 if (laneData.size() != 2)
// Returns the effective lane-layout vector of `tdescType`'s layout attribute
// when it is 2-D; the non-2-D and success return paths are missing from this
// extraction (original lines 63-64, 67+ absent) — presumably std::nullopt /
// the layout vector, mirroring getMaybeLaneData above. TODO confirm against
// the complete source.
60static std::optional<SmallVector<int64_t>>
61getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
62 auto layout = tdescType.getLayoutAttr();
65 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
66 if (laneLayout.size() != 2)
84 if (laneLayout.size() != 2 || laneData.size() != 2)
86 if (laneLayout[0] == 1 || laneLayout[1] != 1)
88 if (laneData[0] != 1 || laneData[1] == 1)
// Type-based overload: decides whether a tensor descriptor is a candidate for
// the transpose (block-load) optimization. Elements of >= 32 bits are
// rejected up front (the early-return body at original line 99 is missing
// from this extraction); otherwise it requires both a 2-D lane layout and
// 2-D lane data and delegates to the (laneLayout, laneData) overload, whose
// definition is only partially visible above.
95static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
97 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
98 if (elementTyBitwidth >= 32)
100 auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
101 auto maybeLaneData = getMaybeLaneData(tdescType);
102 if (!maybeLaneData || !maybeLaneLayout)
104 return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
// Attempts to rewrite `tdescType` into a hardware-friendly tensor descriptor
// for the target micro-architecture: the element type is widened to
// `elementTyBitwidth * innerLaneData` bits (packing `innerLaneData` narrow
// elements into one integer), the innermost dimension is shrunk accordingly,
// and the shape is clamped to what the target's Subgroup2DBlockLoad
// instruction supports. Returns a new TensorDescType on success; the failure
// paths (original lines 112, 135, 137, 145-147 etc.) are missing from this
// extraction — presumably they return `tdescType` unchanged, since callers
// compare the result against the input (see original line 265 below).
109static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
110 const uArch *targetuArch) {
111 if (!canBeOptimizedForTranspose(tdescType))
113 auto laneData = getMaybeLaneData(tdescType)
116 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
// Fold the array length into the innermost dimension and divide by the
// packing factor, so the widened-element shape covers the same bytes.
121 requiredShape.back() =
122 requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
123 int newBitWidth = elementTyBitwidth * innerLaneData;
124 Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
// Query the uArch model for supported 2D block-load widths/heights/counts
// for the widened element type (vnni=false, transpose=true by position —
// TODO confirm parameter meaning against getBlockWidthHeightCount's decl).
127 auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
129 auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
130 newElemTy,
false,
true);
134 auto [widths, heights, counts] = maybeHWParams.value();
// Only a single block (array_length == 1) is handled here.
136 if (counts.size() != 1 || counts[0] != 1)
138 int arrayLen = counts[0];
139 int supportedHeight =
// -1 signals that no supported height/width divides the required shape.
144 if (supportedHeight == -1 || supportedWidth == -1)
148 auto ctx = tdescType.getContext();
149 auto origLayout = tdescType.getLayoutAttr();
150 auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
152 laneLayoutI64.end());
// Rebuild the layout attribute; original lane layout is kept, lane data
// presumably becomes [1, 1] after packing — lines 155-156 are missing.
154 xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
157 origLayout.getOrder());
// Preserve boundary check and memory space from the original descriptor.
160 return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
161 tdescType.getBoundaryCheck(),
162 tdescType.getMemorySpace(), newLayout);
// Materializes an OpFoldResult as an SSA Value. The visible tail casts the
// OpFoldResult directly to Value; the attribute/constant materialization
// branch (original lines 167-170) is missing from this extraction.
166static Value convertToValue(ConversionPatternRewriter &rewriter,
Location loc,
171 return llvm::cast<Value>(ofr);
// Emits IR computing `val / constant` (unsigned): a right-shift when the
// divisor is a power of two, otherwise an arith.divui against a
// materialized constant. The shift-amount constant creation (original lines
// 181-187) is missing from this extraction.
175static Value divideByConstant(ConversionPatternRewriter &rewriter,
Location loc,
178 if (llvm::isPowerOf2_64(constant)) {
179 int64_t shiftAmount = llvm::Log2_64(constant);
180 return arith::ShRUIOp::create(
188 return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
// Decomposes one logical 2-D load into a grid of hardware-supported
// xegpu.LoadNdOp tiles: for each (h, w) tile of `shapeRatio`, it offsets the
// original base offsets by the tile-local offset, issues a load of
// `supportedShape`, and stitches the result back into `data` via
// vector.insert_strided_slice. Returns the accumulated vector (the final
// `return data;` line is missing from this extraction). Cache hints,
// packed/transpose attributes and layout are copied from `origLoadOp`.
194static Value generateLoads(ConversionPatternRewriter &rewriter,
198 xegpu::LoadNdOp origLoadOp) {
200 assert(offsets.size() >= 2 &&
"Expecting at least 2 offsets for 2D LoadNdOp");
// Only the trailing two offsets participate in the 2-D tiling.
201 Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
202 Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
209 for (
int64_t h = 0; h < shapeRatio[0]; ++h) {
210 for (
int64_t w = 0; w < shapeRatio[1]; ++w) {
211 int64_t localOffsetDim0 = h * supportedShape[0];
212 int64_t localOffsetDim1 = w * supportedShape[1];
// Absolute tile offsets = original offset + tile-local offset.
213 Value loadOffsetX = arith::AddIOp::create(
214 rewriter, loc, offsetDim0,
217 Value loadOffsetY = arith::AddIOp::create(
218 rewriter, loc, offsetDim1,
221 auto loadOp = xegpu::LoadNdOp::create(
223 VectorType::get(supportedShape, data.getType().getElementType()),
225 origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
226 origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
227 origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
// Anchor the new load to the converted descriptor's layout.
229 auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
230 loadOp.setAnchorLayout(layoutAttr);
232 auto insertOp = vector::InsertStridedSliceOp::create(
233 rewriter, loc, loadOp.getResult(), data,
238 data = insertOp.getResult();
// Conversion pattern for xegpu.CreateNdDescOp: when tryOptimize produces a
// different (packed/widened) descriptor type, rebuilds the op with the
// innermost size and the second-to-last stride divided by the packing factor
// (innerLaneData), and — for statically-shaped memrefs — lowers the source
// to a raw i64 pointer via memref.extract_aligned_pointer_as_index. Only
// row-major (inner stride == 1) sources are supported.
248class XeGPUCreateNdDescOpPattern final
249 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
251 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
255 auto tdescTy = createNdOp.getType();
// Assertion text below: the pass is only run for pvc/bmg chips (the
// surrounding assert( is on a missing line of this extraction).
260 chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg") &&
261 "Expecting target chip to be pvc or bmg for transpose optimization.");
264 auto convertType = tryOptimize(tdescTy, targetuArch);
// tryOptimize returning the input type means "nothing to do" — the
// early-return body (original line 266) is missing from this extraction.
265 if (convertType == tdescTy)
267 auto strides = createNdOp.getMixedStrides();
270 if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
271 return rewriter.notifyMatchFailure(
272 createNdOp,
"Expecting row-major memref for transpose optimization.");
273 Value source = createNdOp.getSource();
274 auto optionalLaneData = getMaybeLaneData(tdescTy);
275 assert(optionalLaneData &&
"Expected 2D lane data");
276 auto laneData = optionalLaneData.value();
// laneData[1] is the number of narrow elements packed per lane element.
277 int64_t innerLaneData = laneData[1];
278 auto memrefType = dyn_cast<MemRefType>(source.
getType());
// Shrink the innermost size by the packing factor (divisor argument is on
// a missing line — presumably innerLaneData; TODO confirm).
280 SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
281 modifiedShape.back() = divideByConstant(
282 rewriter, createNdOp.getLoc(),
283 convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
286 assert(strides.size() >= 2 &&
287 "Expected at least 2 strides for CreateNdDescOp");
// The second-to-last stride is rescaled for the widened element type.
288 SmallVector<OpFoldResult> modifiedStrides(strides);
289 modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
290 rewriter, createNdOp.getLoc(),
291 convertToValue(rewriter, createNdOp.getLoc(),
292 modifiedStrides[modifiedStrides.size() - 2]),
// Static-shape memrefs are converted to a raw i64 base pointer so the new
// descriptor can reinterpret the underlying memory with the new type.
297 if (memrefType && memrefType.hasStaticShape()) {
298 auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
299 rewriter, createNdOp.getLoc(), source);
300 source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
301 rewriter.getI64Type(),
302 extractOp.getResult())
306 auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
307 rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
309 rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
// Conversion pattern for xegpu.LoadNdOp: when the descriptor operand was
// converted (adaptor type differs from the original), rescales the innermost
// offset by the packing factor, emits the tiled loads via generateLoads into
// a zero-initialized accumulator, and bitcasts the packed result back to the
// original (narrow) element type. Descriptors with array_length > 1 produce
// one slice per array element and replace the op 1:N via
// replaceOpWithMultiple; otherwise a single bitcast value replaces the op.
// Requires explicit offsets on the LoadNdOp.
318class XeGPULoadNdDescOpPattern final
319 :
public OpConversionPattern<xegpu::LoadNdOp> {
321 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
323 matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
324 ConversionPatternRewriter &rewriter)
const override {
325 auto origTensorDescType = loadNdOp.getTensorDescType();
327 cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
// Descriptor unchanged by the type converter => nothing to rewrite (the
// early-return body is on a missing line of this extraction).
328 if (adaptorType == origTensorDescType)
331 auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
332 int64_t innerLaneData = laneData[1];
333 auto offsets = loadNdOp.getMixedOffsets();
// The empty-offsets guard condition is on a missing line (original 334).
335 return rewriter.notifyMatchFailure(loadNdOp,
336 "Expecting offsets in LoadNd");
// Rescale the innermost offset into packed-element units.
337 SmallVector<OpFoldResult> modifiedOffsets(offsets);
338 modifiedOffsets.back() = divideByConstant(
339 rewriter, loadNdOp.getLoc(),
340 convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
344 SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
346 origDataShape.back() /= innerLaneData;
348 SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
349 VectorType origVectorType =
350 VectorType::get(origDataShape, adaptorType.getElementType());
// Multi-slice case: one independent load chain per array element.
353 if (origTensorDescType.getArrayLength() > 1) {
354 SmallVector<Value> arraySlices;
355 for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
// Zero accumulator that generateLoads fills tile by tile.
356 Value slice = arith::ConstantOp::create(
357 rewriter, loadNdOp->getLoc(), origVectorType,
358 rewriter.getZeroAttr(origVectorType));
// Advance the innermost offset by one slice width per iteration.
// NOTE(review): modifiedOffsets.back() is mutated across iterations, so
// offsets accumulate — verify intent against the complete source.
360 Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
361 modifiedOffsets.back());
362 modifiedOffsets.back() =
363 arith::AddIOp::create(
364 rewriter, loadNdOp->getLoc(), offsetY,
366 i * origDataShape[1])
369 slice = generateLoads(
// Restore the original narrow element type for consumers.
374 auto bitcastType = VectorType::get(origTensorDescType.getShape(),
375 origTensorDescType.getElementType());
376 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
380 origTensorDescType.getLayoutAttr());
381 arraySlices.push_back(bitCastOp.getResult());
383 rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
// Single-slice case.
386 data = arith::ConstantOp::create(
387 rewriter, loadNdOp->getLoc(),
388 VectorType::get(origDataShape, adaptorType.getElementType()),
389 rewriter.getZeroAttr(origVectorType));
390 data = generateLoads(
394 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
395 loadNdOp.getType(), data);
398 origTensorDescType.getLayoutAttr());
399 rewriter.replaceOp(loadNdOp, bitCastOp);
// 1:N conversion pattern for vector.extract: when a LoadNdOp with
// array_length > 1 was replaced by multiple slice values, a single-index
// extract from the original result is forwarded directly to the matching
// replacement slice. Non-1:N sources (size == 1) and multi-index positions
// bail out (the bail-out bodies are on missing lines of this extraction).
408class VectorExtractOpPattern final
409 :
public OpConversionPattern<vector::ExtractOp> {
411 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
413 matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
414 ConversionPatternRewriter &rewriter)
const override {
416 if (adaptor.getSource().size() == 1)
418 auto mixedPos = extractOp.getMixedPosition();
419 if (mixedPos.size() != 1)
// mayBeInt (constant extraction of the position) comes from a missing
// line — presumably getConstantIntValue(mixedPos[0]); TODO confirm.
424 rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
// Conversion pattern that decomposes a 2-D vector.multi_reduction into two
// chained 1-D reductions: first along the intra-lane dimension (the one the
// lane layout keeps per-lane), seeded with the reduction kind's neutral
// value, then along the cross-lane dimension, seeded with the original
// accumulator. A xegpu.ConvertLayoutOp bridges the post-decomposition layout
// back to the result layout when both are known.
431class MultiRed2dOpPattern
432 :
public OpConversionPattern<vector::MultiDimReductionOp> {
433 using OpConversionPattern::OpConversionPattern;
435 matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
436 ConversionPatternRewriter &rewriter)
const override {
437 auto sourceVecType = reductionOp.getSourceVectorType();
438 if (reductionOp.getReductionDims().size() != 2)
439 return rewriter.notifyMatchFailure(reductionOp,
"Expected 2D reduction");
442 auto dims = llvm::to_vector(reductionOp.getReductionDims());
443 auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
// Fallback ordering when the layout does not disambiguate the dims.
445 if (intraLaneDim == -1 || crossLaneDim == -1) {
446 intraLaneDim = dims[0];
447 crossLaneDim = dims[1];
449 auto loc = reductionOp.getLoc();
450 auto acc = reductionOp.getAcc();
462 xegpu::DistributeLayoutAttr postDecompLayout;
465 xegpu::DistributeLayoutAttr srcLayoutForCvt;
// Prefer the parent of a slice result-layout as the source layout.
466 if (
auto resSlice = dyn_cast_if_present<xegpu::SliceAttr>(resLayout))
467 srcLayoutForCvt = resSlice.getParent();
468 if (!srcLayoutForCvt)
471 if (srcLayoutForCvt) {
477 MLIRContext *ctx = reductionOp.getContext();
// After erasing intraLaneDim, the cross-lane dim index shifts down by one
// if it followed the intra-lane dim.
478 int64_t adjCrossLaneDim =
479 crossLaneDim > intraLaneDim ? crossLaneDim - 1 : crossLaneDim;
480 auto intermediateLayout = xegpu::SliceAttr::get(
482 postDecompLayout = xegpu::SliceAttr::get(
483 ctx, intermediateLayout,
// Accumulator for the first reduction: source shape minus intraLaneDim,
// filled with the neutral value of the reduction kind.
488 SmallVector<int64_t> accShape(sourceVecType.getShape());
489 accShape.erase(accShape.begin() + intraLaneDim);
490 Type eTy = sourceVecType.getElementType();
492 rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
494 Value intraLaneReduced = vector::MultiDimReductionOp::create(
495 rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
496 constNeutralVal, ArrayRef<int64_t>(intraLaneDim));
// Re-index the cross-lane dim into the rank-reduced intermediate vector
// (the adjustment statement at original line 500 is missing here).
499 if (crossLaneDim > intraLaneDim)
501 Value crossLaneReduced = vector::MultiDimReductionOp::create(
502 rewriter, loc, reductionOp.getKind(), intraLaneReduced, acc,
503 ArrayRef<int64_t>(crossLaneDim));
504 assert(crossLaneReduced.
getType() == reductionOp.getResult().getType() &&
508 if (resLayout && postDecompLayout) {
514 auto bridgeOp = xegpu::ConvertLayoutOp::create(
515 rewriter, loc, crossLaneReduced.
getType(), crossLaneReduced,
516 postDecompLayout, resLayout);
// Classifies the two reduction dims as (intra-lane, cross-lane) based on
// the effective lane layout: a dim with lane extent 1 is intra-lane.
// Returns {-1, -1}-style sentinels when undetermined. NOTE(review): as
// extracted, only `cross` is initialized at original line 530 — `intra`
// may rely on assignment inside the (partially missing) loop; verify
// against the complete source.
525 std::pair<int64_t, int64_t>
526 getReductionDimOrder(ArrayRef<int64_t> reductionDims,
527 xegpu::DistributeLayoutAttr layout)
const {
528 assert(layout.isForSubgroup() &&
"Must know the lane layout");
529 assert(reductionDims.size() == 2 &&
"Expected 2D reduction");
530 int64_t intra, cross = -1;
531 xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
// Slice layouts are flattened to their parent LayoutAttr first.
532 if (
auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
534 dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
536 SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
538 assert(laneLayout.size() &&
"Expected a non-empty layout");
540 for (
auto dim : reductionDims) {
541 if (laneLayout[dim] == 1)
546 return {intra, cross};
554 patterns.
add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
555 VectorExtractOpPattern, MultiRed2dOpPattern>(
// Pass driver. Pipeline as visible here: (1) bail out unless some
// gpu.func's chip string is "pvc" or "bmg"; (2) greedily apply the array
// length optimization patterns; (3) run a partial dialect conversion where
// CreateNdDescOp / LoadNdOp / single-index vector.extract / 2-D
// multi_reduction ops are legal only when NOT candidates for the transpose
// optimization, with an identity type converter; (4) a final cleanup step
// using `emptyPatterns` whose body runs past this extraction.
561struct XeGPUPeepHoleOptimizerPass final
563 XeGPUPeepHoleOptimizerPass> {
564 void runOnOperation()
override {
572 bool isTargetSupported =
false;
// Scan all GPU functions for a supported chip attribute.
573 getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
575 if (chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg"))
576 isTargetSupported =
true;
// Unsupported target: debug-log and (per the missing lines) return early.
579 if (!isTargetSupported) {
580 DBGS() <<
"XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
588 RewritePatternSet arrayLenPatterns(&context);
591 std::move(arrayLenPatterns)))) {
592 DBGS() <<
"Array length optimization patterns failed.\n";
593 return signalPassFailure();
// Ops remain legal exactly when the transpose optimization does not apply,
// so the conversion only fires on optimization candidates.
599 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
600 [&](xegpu::CreateNdDescOp createNdOp) {
601 return !canBeOptimizedForTranspose(createNdOp.getType());
603 target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
604 [&](xegpu::LoadNdOp loadNdOp) {
605 return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
610 target.addDynamicallyLegalOp<vector::ExtractOp>(
611 [&](vector::ExtractOp extractOp) {
613 dyn_cast<OpResult>(extractOp.getResult()));
616 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
617 auto laneData = layout.getEffectiveLaneDataAsInt();
618 return !canBeOptimizedForTranspose(laneLayout, laneData);
// 2-D subgroup reductions are illegal (must be decomposed); anything else
// is legal.
621 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
622 [=](Operation *op) ->
bool {
624 if (!layout || !layout.isForSubgroup())
626 if (
auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
627 return reductionOp.getReductionDims().size() != 2;
// Identity type conversion: types are rewritten by the patterns, not the
// converter.
631 converter.addConversion([](Type type) {
return type; });
633 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
634 vector::VectorDialect>();
637 target.addLegalOp<xegpu::ConvertLayoutOp>();
641 if (
failed(applyPartialConversion(getOperation(),
target,
642 std::move(patterns)))) {
643 DBGS() <<
"Optimize block loads pass failed.\n";
644 return signalPassFailure();
649 RewritePatternSet emptyPatterns(ctx);
* If copies could not be generated due to yet-unimplemented cases.
* `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock`
* specify the insertion points for the incoming and outgoing copies,
* respectively. The output argument `nBegin` is set to the replacement of
* `begin` (it is set to `begin` itself if no invalidation happens), since
* outgoing copies could have been inserted at `end`.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns)
Appends patterns for optimizing block load operations into patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
const Instruction * getInstruction(InstructionKind instKind) const