29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
35#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
40#define DEBUG_TYPE "xegpu-optimize-peephole"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
48static std::optional<SmallVector<int64_t>>
49getMaybeLaneData(xegpu::TensorDescType tdescType) {
50 auto layout = tdescType.getLayoutAttr();
53 auto laneData = layout.getEffectiveLaneDataAsInt();
54 if (laneData.size() != 2)
60static std::optional<SmallVector<int64_t>>
61getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
62 auto layout = tdescType.getLayoutAttr();
65 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
66 if (laneLayout.size() != 2)
84 if (laneLayout.size() != 2 || laneData.size() != 2)
86 if (laneLayout[0] == 1 || laneLayout[1] != 1)
88 if (laneData[0] != 1 || laneData[1] == 1)
95static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
97 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
98 if (elementTyBitwidth >= 32)
100 auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
101 auto maybeLaneData = getMaybeLaneData(tdescType);
102 if (!maybeLaneData || !maybeLaneLayout)
104 return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
109static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
111 if (!canBeOptimizedForTranspose(tdescType))
113 auto laneData = getMaybeLaneData(tdescType)
116 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
121 requiredShape.back() =
122 requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
123 int newBitWidth = elementTyBitwidth * innerLaneData;
124 Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
127 auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
129 auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
130 newElemTy,
false,
true);
134 auto [widths, heights, counts] = maybeHWParams.value();
136 if (counts.size() != 1 || counts[0] != 1)
138 int arrayLen = counts[0];
139 int supportedHeight =
144 if (supportedHeight == -1 || supportedWidth == -1)
148 auto ctx = tdescType.getContext();
149 auto origLayout = tdescType.getLayoutAttr();
150 auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
152 laneLayoutI64.end());
154 xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
157 origLayout.getOrder());
160 return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
161 tdescType.getBoundaryCheck(),
162 tdescType.getMemorySpace(), newLayout);
166static Value convertToValue(ConversionPatternRewriter &rewriter,
Location loc,
171 return llvm::cast<Value>(ofr);
175static Value divideByConstant(ConversionPatternRewriter &rewriter,
Location loc,
178 if (llvm::isPowerOf2_64(constant)) {
179 int64_t shiftAmount = llvm::Log2_64(constant);
180 return arith::ShRUIOp::create(
188 return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
194static Value generateLoads(ConversionPatternRewriter &rewriter,
198 xegpu::LoadNdOp origLoadOp) {
200 assert(offsets.size() >= 2 &&
"Expecting at least 2 offsets for 2D LoadNdOp");
201 Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
202 Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
209 for (
int64_t h = 0; h < shapeRatio[0]; ++h) {
210 for (
int64_t w = 0; w < shapeRatio[1]; ++w) {
211 int64_t localOffsetDim0 = h * supportedShape[0];
212 int64_t localOffsetDim1 = w * supportedShape[1];
213 Value loadOffsetX = arith::AddIOp::create(
214 rewriter, loc, offsetDim0,
217 Value loadOffsetY = arith::AddIOp::create(
218 rewriter, loc, offsetDim1,
221 auto loadOp = xegpu::LoadNdOp::create(
223 VectorType::get(supportedShape, data.getType().getElementType()),
225 origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
226 origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
227 origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
229 auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
230 loadOp.setAnchorLayout(layoutAttr);
232 auto insertOp = vector::InsertStridedSliceOp::create(
233 rewriter, loc, loadOp.getResult(), data,
238 data = insertOp.getResult();
248class XeGPUCreateNdDescOpPattern final
249 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
251 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
255 auto tdescTy = createNdOp.getType();
259 assert(chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg") &&
260 "Expecting target chip to be pvc, bmg for transpose optimization.");
263 auto convertType = tryOptimize(tdescTy, targetuArch);
264 if (convertType == tdescTy)
266 auto strides = createNdOp.getMixedStrides();
269 if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
270 return rewriter.notifyMatchFailure(
271 createNdOp,
"Expecting row-major memref for transpose optimization.");
272 Value source = createNdOp.getSource();
273 auto optionalLaneData = getMaybeLaneData(tdescTy);
274 assert(optionalLaneData &&
"Expected 2D lane data");
275 auto laneData = optionalLaneData.value();
276 int64_t innerLaneData = laneData[1];
277 auto memrefType = dyn_cast<MemRefType>(source.
getType());
279 SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
280 modifiedShape.back() = divideByConstant(
281 rewriter, createNdOp.getLoc(),
282 convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
285 assert(strides.size() >= 2 &&
286 "Expected at least 2 strides for CreateNdDescOp");
287 SmallVector<OpFoldResult> modifiedStrides(strides);
288 modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
289 rewriter, createNdOp.getLoc(),
290 convertToValue(rewriter, createNdOp.getLoc(),
291 modifiedStrides[modifiedStrides.size() - 2]),
296 if (memrefType && memrefType.hasStaticShape()) {
297 auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
298 rewriter, createNdOp.getLoc(), source);
299 source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
300 rewriter.getI64Type(),
301 extractOp.getResult())
305 auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
306 rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
308 rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
317class XeGPULoadNdDescOpPattern final
318 :
public OpConversionPattern<xegpu::LoadNdOp> {
320 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
322 matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter)
const override {
324 auto origTensorDescType = loadNdOp.getTensorDescType();
326 cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
327 if (adaptorType == origTensorDescType)
330 auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
331 int64_t innerLaneData = laneData[1];
332 auto offsets = loadNdOp.getMixedOffsets();
334 return rewriter.notifyMatchFailure(loadNdOp,
335 "Expecting offsets in LoadNd");
336 SmallVector<OpFoldResult> modifiedOffsets(offsets);
337 modifiedOffsets.back() = divideByConstant(
338 rewriter, loadNdOp.getLoc(),
339 convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
343 SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
345 origDataShape.back() /= innerLaneData;
347 SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
348 VectorType origVectorType =
349 VectorType::get(origDataShape, adaptorType.getElementType());
352 if (origTensorDescType.getArrayLength() > 1) {
353 SmallVector<Value> arraySlices;
354 for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
355 Value slice = arith::ConstantOp::create(
356 rewriter, loadNdOp->getLoc(), origVectorType,
357 rewriter.getZeroAttr(origVectorType));
359 Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
360 modifiedOffsets.back());
361 modifiedOffsets.back() =
362 arith::AddIOp::create(
363 rewriter, loadNdOp->getLoc(), offsetY,
365 i * origDataShape[1])
368 slice = generateLoads(
373 auto bitcastType = VectorType::get(origTensorDescType.getShape(),
374 origTensorDescType.getElementType());
375 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
379 origTensorDescType.getLayoutAttr());
380 arraySlices.push_back(bitCastOp.getResult());
382 rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
385 data = arith::ConstantOp::create(
386 rewriter, loadNdOp->getLoc(),
387 VectorType::get(origDataShape, adaptorType.getElementType()),
388 rewriter.getZeroAttr(origVectorType));
389 data = generateLoads(
393 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
394 loadNdOp.getType(), data);
397 origTensorDescType.getLayoutAttr());
398 rewriter.replaceOp(loadNdOp, bitCastOp);
407class VectorExtractOpPattern final
408 :
public OpConversionPattern<vector::ExtractOp> {
410 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
412 matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter)
const override {
415 if (adaptor.getSource().size() == 1)
417 auto mixedPos = extractOp.getMixedPosition();
418 if (mixedPos.size() != 1)
423 rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
430class MultiRed2dOpPattern
431 :
public OpConversionPattern<vector::MultiDimReductionOp> {
432 using OpConversionPattern::OpConversionPattern;
434 matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
435 ConversionPatternRewriter &rewriter)
const override {
436 auto sourceVecType = reductionOp.getSourceVectorType();
437 if (reductionOp.getReductionDims().size() != 2)
438 return rewriter.notifyMatchFailure(reductionOp,
"Expected 2D reduction");
441 auto dims = llvm::to_vector(reductionOp.getReductionDims());
442 auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
444 if (intraLaneDim == -1 || crossLaneDim == -1) {
445 intraLaneDim = dims[0];
446 crossLaneDim = dims[1];
448 auto loc = reductionOp.getLoc();
449 auto acc = reductionOp.getAcc();
461 xegpu::DistributeLayoutAttr postDecompLayout;
464 xegpu::DistributeLayoutAttr srcLayoutForCvt;
465 if (
auto resSlice = dyn_cast_if_present<xegpu::SliceAttr>(resLayout))
466 srcLayoutForCvt = resSlice.getParent();
467 if (!srcLayoutForCvt)
470 if (srcLayoutForCvt) {
476 MLIRContext *ctx = reductionOp.getContext();
477 int64_t adjCrossLaneDim =
478 crossLaneDim > intraLaneDim ? crossLaneDim - 1 : crossLaneDim;
479 auto intermediateLayout = xegpu::SliceAttr::get(
481 postDecompLayout = xegpu::SliceAttr::get(
482 ctx, intermediateLayout,
487 SmallVector<int64_t> accShape(sourceVecType.getShape());
488 accShape.erase(accShape.begin() + intraLaneDim);
489 Type eTy = sourceVecType.getElementType();
491 rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
493 Value intraLaneReduced = vector::MultiDimReductionOp::create(
494 rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
495 constNeutralVal, ArrayRef<int64_t>(intraLaneDim));
498 if (crossLaneDim > intraLaneDim)
500 Value crossLaneReduced = vector::MultiDimReductionOp::create(
501 rewriter, loc, reductionOp.getKind(), intraLaneReduced, acc,
502 ArrayRef<int64_t>(crossLaneDim));
503 assert(crossLaneReduced.
getType() == reductionOp.getResult().getType() &&
507 if (resLayout && postDecompLayout) {
513 auto bridgeOp = xegpu::ConvertLayoutOp::create(
514 rewriter, loc, crossLaneReduced.
getType(), crossLaneReduced,
515 postDecompLayout, resLayout);
524 std::pair<int64_t, int64_t>
525 getReductionDimOrder(ArrayRef<int64_t> reductionDims,
526 xegpu::DistributeLayoutAttr layout)
const {
527 assert(layout.isForSubgroup() &&
"Must know the lane layout");
528 assert(reductionDims.size() == 2 &&
"Expected 2D reduction");
529 int64_t intra, cross = -1;
530 xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
531 if (
auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
533 dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
535 SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
537 assert(laneLayout.size() &&
"Expected a non-empty layout");
539 for (
auto dim : reductionDims) {
540 if (laneLayout[dim] == 1)
545 return {intra, cross};
553 patterns.
add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
554 VectorExtractOpPattern, MultiRed2dOpPattern>(
560struct XeGPUPeepHoleOptimizerPass final
562 XeGPUPeepHoleOptimizerPass> {
563 void runOnOperation()
override {
571 bool isTargetSupported =
false;
572 getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
574 if (chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg"))
575 isTargetSupported =
true;
578 if (!isTargetSupported) {
579 DBGS() <<
"XeGPUPeepHoleOptimizerPass only supports PVC, BMG targets."
587 RewritePatternSet arrayLenPatterns(&context);
590 std::move(arrayLenPatterns)))) {
591 DBGS() <<
"Array length optimization patterns failed.\n";
592 return signalPassFailure();
598 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
599 [&](xegpu::CreateNdDescOp createNdOp) {
600 return !canBeOptimizedForTranspose(createNdOp.getType());
602 target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
603 [&](xegpu::LoadNdOp loadNdOp) {
604 return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
609 target.addDynamicallyLegalOp<vector::ExtractOp>(
610 [&](vector::ExtractOp extractOp) {
612 dyn_cast<OpResult>(extractOp.getResult()));
615 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
616 auto laneData = layout.getEffectiveLaneDataAsInt();
617 return !canBeOptimizedForTranspose(laneLayout, laneData);
620 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
621 [=](Operation *op) ->
bool {
623 if (!layout || !layout.isForSubgroup())
625 if (
auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
626 return reductionOp.getReductionDims().size() != 2;
630 converter.addConversion([](Type type) {
return type; });
632 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
633 vector::VectorDialect>();
636 target.addLegalOp<xegpu::ConvertLayoutOp>();
640 if (
failed(applyPartialConversion(getOperation(),
target,
641 std::move(patterns)))) {
642 DBGS() <<
"Optimize block loads pass failed.\n";
643 return signalPassFailure();
648 RewritePatternSet emptyPatterns(ctx);
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
void removeTemporaryLayoutAttrs(Operation *op)
Removes the temporary layout attributes for each OpOperand and OpResult of the given operation.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns)
Appends patterns for optimizing block load operations into patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
const Instruction * getInstruction(InstructionKind instKind) const