29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
35#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
40#define DEBUG_TYPE "xegpu-optimize-peephole"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Extracts the effective lane-data of the descriptor's layout attribute.
// The visible check requires it to be rank-2; the early-return bodies and
// the function tail are missing from this extraction, so the nullopt /
// success return paths cannot be confirmed here.
// NOTE(review): stale line numbers are fused into the code text below;
// code is kept byte-identical, comments only.
48static std::optional<SmallVector<int64_t>>
49getMaybeLaneData(xegpu::TensorDescType tdescType) {
50 auto layout = tdescType.getLayoutAttr();
53 auto laneData = layout.getEffectiveLaneDataAsInt();
54 if (laneData.size() != 2)
// Companion to getMaybeLaneData: extracts the layout's effective
// lane-layout and (visibly) requires it to be rank-2. Early-return bodies
// and the function tail are missing from this extraction.
60static std::optional<SmallVector<int64_t>>
61getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
62 auto layout = tdescType.getLayoutAttr();
65 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
66 if (laneLayout.size() != 2)
// Fragment of the (laneLayout, laneData) predicate: the enclosing
// signature and the branch bodies are not visible in this extraction.
// The visible conditions reject anything that is not: rank-2 vectors, a
// lane layout of the form [N, 1] with N != 1, and lane data of the form
// [1, M] with M != 1 (i.e. data packed along the inner dimension) --
// presumably by returning false; confirm against the full source.
84 if (laneLayout.size() != 2 || laneData.size() != 2)
86 if (laneLayout[0] == 1 || laneLayout[1] != 1)
88 if (laneData[0] != 1 || laneData[1] == 1)
// TensorDescType overload: a descriptor is a transpose-optimization
// candidate only when its element type is narrower than 32 bits and both
// effective lane-layout and lane-data are available; the final decision
// is delegated to the vector overload above. Early-return bodies are
// missing from this extraction.
95static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
97 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
98 if (elementTyBitwidth >= 32)
100 auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
101 auto maybeLaneData = getMaybeLaneData(tdescType);
102 if (!maybeLaneData || !maybeLaneLayout)
104 return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
// Computes a replacement TensorDescType whose elements are widened to
// elementTyBitwidth * innerLaneData bits (with the inner shape dimension
// shrunk by the same factor) so the target's Subgroup2DBlockLoad
// instruction can service it, and returns a descriptor built from the
// hardware-supported shape. Several interior lines (requiredShape /
// supportedWidth / supportedShape computation, the early-return bodies)
// are missing from this extraction, so code below is kept byte-identical
// with comments only.
109static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
110 const uArch *targetuArch) {
111 if (!canBeOptimizedForTranspose(tdescType))
113 auto laneData = getMaybeLaneData(tdescType)
115 int64_t innerLaneData = laneData[1];
116 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
// Fold the array length and the lane-data packing factor into the
// innermost required dimension.
121 requiredShape.back() =
122 requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
123 int newBitWidth = elementTyBitwidth * innerLaneData;
124 Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
// Query the target uArch's 2D block-load instruction description for
// legal block widths / heights / counts at the widened element type.
127 auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
128 targetuArch->
getInstruction(InstructionKind::Subgroup2DBlockLoad));
129 auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
130 newElemTy,
false,
true);
// Only a single supported block count of 1 is handled here.
134 auto [widths, heights, counts] = maybeHWParams.value();
136 if (counts.size() != 1 || counts[0] != 1)
138 int arrayLen = counts[0];
139 int supportedHeight =
144 if (supportedHeight == -1 || supportedWidth == -1)
// Rebuild the layout attribute for the new descriptor from the original
// effective lane layout (intervening construction lines are missing
// from this extraction).
148 auto ctx = tdescType.getContext();
149 auto origLayout = tdescType.getLayoutAttr();
150 auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
152 laneLayoutI64.end());
154 xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
157 origLayout.getOrder());
160 return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
161 tdescType.getBoundaryCheck(),
162 tdescType.getMemorySpace(), newLayout);
// Materializes an OpFoldResult as a Value. Only the Value-cast tail is
// visible here; the attribute-to-constant path (between the signature
// and the cast) is missing from this extraction.
166static Value convertToValue(ConversionPatternRewriter &rewriter,
Location loc,
171 return llvm::cast<Value>(ofr);
// Emits IR computing val / constant: strength-reduced to a logical right
// shift when the divisor is a power of two, otherwise an unsigned divide
// against a materialized constant (the constantOp creation lines are
// missing from this extraction).
175static Value divideByConstant(ConversionPatternRewriter &rewriter,
Location loc,
178 if (llvm::isPowerOf2_64(constant)) {
179 int64_t shiftAmount = llvm::Log2_64(constant);
180 return arith::ShRUIOp::create(
188 return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
// Decomposes one large 2D load into a shapeRatio[0] x shapeRatio[1] grid
// of hardware-sized sub-loads of supportedShape and stitches the tiles
// into `data` with vector.insert_strided_slice. Each tile's offsets are
// the original trailing-2D offsets plus the tile-local offsets. Cache
// hints and pack/transpose attributes are forwarded from the original
// load, and each sub-load is anchored to the new descriptor's layout.
// Loop-closing braces and the final return are missing from this
// extraction; code kept byte-identical.
194static Value generateLoads(ConversionPatternRewriter &rewriter,
198 xegpu::LoadNdOp origLoadOp) {
200 assert(offsets.size() >= 2 &&
"Expecting at least 2 offsets for 2D LoadNdOp");
201 Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
202 Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
209 for (
int64_t h = 0; h < shapeRatio[0]; ++h) {
210 for (
int64_t w = 0; w < shapeRatio[1]; ++w) {
// Tile-local offsets within the original block.
211 int64_t localOffsetDim0 = h * supportedShape[0];
212 int64_t localOffsetDim1 = w * supportedShape[1];
213 Value loadOffsetX = arith::AddIOp::create(
214 rewriter, loc, offsetDim0,
217 Value loadOffsetY = arith::AddIOp::create(
218 rewriter, loc, offsetDim1,
221 auto loadOp = xegpu::LoadNdOp::create(
223 VectorType::get(supportedShape, data.getType().getElementType()),
225 origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
226 origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
227 origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
// Re-anchor the sub-load to the converted descriptor's layout.
229 auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
230 loadOp.setAnchorLayout(layoutAttr);
232 auto insertOp = vector::InsertStridedSliceOp::create(
233 rewriter, loc, loadOp.getResult(), data,
238 data = insertOp.getResult();
// Conversion pattern for xegpu.create_nd_tdesc: swaps the result type
// for the packed descriptor computed by tryOptimize. Requires a
// row-major source (constant unit innermost stride); the innermost shape
// entry and the second-to-last stride are divided by innerLaneData to
// match the widened element type, and statically shaped memref sources
// are rewritten to a raw i64 base pointer. Many interior lines (uArch
// lookup, modifiedShape/modifiedStrides initialization, final return)
// are missing from this extraction; code kept byte-identical.
248class XeGPUCreateNdDescOpPattern final
249    :
public OpConversionPattern<xegpu::CreateNdDescOp> {
251 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
253 matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
255 auto tdescTy = createNdOp.getType();
// Pass precondition: only PVC/BMG targets reach this pattern.
260 chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg") &&
261 "Expecting target chip to be pvc or bmg for transpose optimization.");
264 auto convertType = tryOptimize(tdescTy, targetuArch);
265 if (convertType == tdescTy)
// The optimization assumes a row-major view: the innermost stride must
// be the constant 1.
267 auto strides = createNdOp.getMixedStrides();
270 if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
271 return rewriter.notifyMatchFailure(
272 createNdOp,
"Expecting row-major memref for transpose optimization.");
273 Value source = createNdOp.getSource();
274 auto optionalLaneData = getMaybeLaneData(tdescTy);
275 assert(optionalLaneData &&
"Expected 2D lane data");
276 auto laneData = optionalLaneData.value();
277 int64_t innerLaneData = laneData[1];
278 auto memrefType = dyn_cast<MemRefType>(source.
getType());
// Innermost dimension shrinks by the packing factor.
281 modifiedShape.back() = divideByConstant(
282 rewriter, createNdOp.getLoc(),
283 convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
286 assert(strides.size() >= 2 &&
287 "Expected at least 2 strides for CreateNdDescOp")
289 modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
290 rewriter, createNdOp.getLoc(),
291 convertToValue(rewriter, createNdOp.getLoc(),
292 modifiedStrides[modifiedStrides.size() - 2]),
// Static memrefs: hand the descriptor a raw i64 pointer instead of the
// memref value.
297 if (memrefType && memrefType.hasStaticShape()) {
298 auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
299 rewriter, createNdOp.getLoc(), source);
300 source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
301 rewriter.getI64Type(),
302 extractOp.getResult())
306 auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
307 rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
309 rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
// Conversion pattern for xegpu.load_nd whose descriptor was converted:
// scales the innermost offset and data shape down by innerLaneData,
// issues tile-sized loads via generateLoads, and bitcasts the packed
// result back to the original element type/shape. For descriptors with
// array_length > 1 one slice per array index is produced and the op is
// replaced with the slice list (replaceOpWithMultiple); otherwise a
// single bitcast replaces the op. Numerous interior lines are missing
// from this extraction; code kept byte-identical.
318class XeGPULoadNdDescOpPattern final
319    :
public OpConversionPattern<xegpu::LoadNdOp> {
321 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
323 matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
324 ConversionPatternRewriter &rewriter)
const override {
325 auto origTensorDescType = loadNdOp.getTensorDescType();
327 cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
// Nothing to do if the descriptor type was not converted.
328 if (adaptorType == origTensorDescType)
331 auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
332 int64_t innerLaneData = laneData[1];
333 auto offsets = loadNdOp.getMixedOffsets();
335 return rewriter.notifyMatchFailure(loadNdOp,
336 "Expecting offsets in LoadNd");
// The innermost offset is rescaled to the packed element width.
338 modifiedOffsets.back() = divideByConstant(
339 rewriter, loadNdOp.getLoc(),
340 convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
346 origDataShape.back() /= innerLaneData;
349 VectorType origVectorType =
350 VectorType::get(origDataShape, adaptorType.getElementType());
// array_length > 1: emit one loaded+bitcast slice per array index,
// advancing the inner offset by origDataShape[1] each iteration.
353 if (origTensorDescType.getArrayLength() > 1) {
355 for (
int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
356 Value slice = arith::ConstantOp::create(
357 rewriter, loadNdOp->getLoc(), origVectorType,
358 rewriter.getZeroAttr(origVectorType));
360 Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
361 modifiedOffsets.back());
362 modifiedOffsets.back() =
363 arith::AddIOp::create(
364 rewriter, loadNdOp->getLoc(), offsetY,
366 i * origDataShape[1])
369 slice = generateLoads(
374 auto bitcastType = VectorType::get(origTensorDescType.getShape(),
375 origTensorDescType.getElementType());
376 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
380 origTensorDescType.getLayoutAttr());
381 arraySlices.push_back(bitCastOp.getResult());
383 rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
// Single-slice case: load into a zero-initialized accumulator and
// bitcast to the op's original result type.
386 data = arith::ConstantOp::create(
387 rewriter, loadNdOp->getLoc(),
388 VectorType::get(origDataShape, adaptorType.getElementType()),
389 rewriter.getZeroAttr(origVectorType));
390 data = generateLoads(
394 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
395 loadNdOp.getType(), data);
398 origTensorDescType.getLayoutAttr());
399 rewriter.replaceOp(loadNdOp, bitCastOp);
// Conversion pattern for vector.extract: when the converted producer was
// expanded into multiple slice values (1->N remapping by the load-nd
// pattern above), an extract with a single constant position folds
// directly to the matching remapped slice. Early-return bodies and the
// mayBeInt definition are missing from this extraction.
408class VectorExtractOpPattern final
409    :
public OpConversionPattern<vector::ExtractOp> {
411 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
414 ConversionPatternRewriter &rewriter)
const override {
416 if (adaptor.getSource().size() == 1)
418 auto mixedPos = extractOp.getMixedPosition();
419 if (mixedPos.size() != 1)
424 rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
// Conversion pattern for 2D vector.multi_reduction: splits it into an
// intra-lane reduction (neutral-value accumulator) followed by a
// cross-lane reduction that receives the original accumulator, with the
// dimension order chosen from the result's distribute layout via
// getReductionDimOrder (falling back to the declared dim order when the
// layout is inconclusive). A consuming xegpu.convert_layout is folded
// into the reduction result. Several interior lines are missing from
// this extraction; code kept byte-identical.
431class MultiRed2dOpPattern
432    :
public OpConversionPattern<vector::MultiDimReductionOp> {
433 using OpConversionPattern::OpConversionPattern;
435 matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
436 ConversionPatternRewriter &rewriter)
const override {
437 auto sourceVecType = reductionOp.getSourceVectorType();
438 if (reductionOp.getReductionDims().size() != 2)
439 return rewriter.notifyMatchFailure(reductionOp,
"Expected 2D reduction");
442 auto dims = llvm::to_vector(reductionOp.getReductionDims());
443 auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
// Layout did not determine an order: keep the reduction's own order.
445 if (intraLaneDim == -1 || crossLaneDim == -1) {
446 intraLaneDim = dims[0];
447 crossLaneDim = dims[1];
449 auto loc = reductionOp.getLoc();
450 auto acc = reductionOp.getAcc();
455 Type resultType = reductionOp.getResult().getType();
// Fold away a convert_layout user of the result.
457 for (
auto &use : reductionOp.getResult().getUses()) {
458 if (
auto convertLayoutOp =
459 llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
460 rewriter.replaceOp(convertLayoutOp, reductionOp.getResult());
// First stage reduces the intra-lane dimension against a neutral
// accumulator of the reduced shape.
467 accShape.erase(accShape.begin() + intraLaneDim);
468 Type eTy = sourceVecType.getElementType();
470 rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
472 Value intraLaneReduced = vector::MultiDimReductionOp::create(
473 rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
// Adjust the cross-lane dim index after the intra-lane dim was erased.
477 if (crossLaneDim > intraLaneDim)
479 Value crossLaneReduced = vector::MultiDimReductionOp::create(
480 rewriter, loc, reductionOp.getKind(), intraLaneReduced,
acc,
482 assert(crossLaneReduced.
getType() == reductionOp.getResult().getType() &&
484 rewriter.replaceOp(reductionOp, crossLaneReduced);
// Maps the two reduction dims to {intra-lane, cross-lane} order using
// the lane layout (a SliceAttr is flattened to its parent LayoutAttr
// first). The branch bodies are missing from this extraction --
// presumably a dim whose laneLayout entry is 1 (not distributed across
// lanes) becomes the intra-lane dim; confirm against the full source.
489 std::pair<int64_t, int64_t>
491 xegpu::DistributeLayoutAttr layout)
const {
492 assert(layout.isForSubgroup() &&
"Must know the lane layout");
493 assert(reductionDims.size() == 2 &&
"Expected 2D reduction");
495 xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
496 if (
auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
498 dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
502 assert(laneLayout.size() &&
"Expected a non-empty layout");
504 for (
auto dim : reductionDims) {
505 if (laneLayout[dim] == 1)
510 return {intra, cross};
// Fragment of populateXeGPUPeepHoleOptimizerPatterns: registers the four
// conversion patterns defined in this file (the function signature and
// closing lines are missing from this extraction).
518 patterns.
add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
519 VectorExtractOpPattern, MultiRed2dOpPattern>(
// Pass driver. Runs only when some gpu.func's chip string is "pvc" or
// "bmg"; otherwise bails out with a debug message. Conversion legality:
// create_nd_tdesc / load_nd are legal unless transpose-optimizable;
// vector.extract is legal unless its result layout matches the
// optimizable lane pattern; 2D multi_reductions with a subgroup
// distribute layout are illegal. An identity type converter is used,
// arith/memref/vector dialects are fully legal, and after the partial
// conversion all discardable DistributeLayout attributes are stripped
// from every op. Many interior lines are missing from this extraction;
// code kept byte-identical.
525struct XeGPUPeepHoleOptimizerPass final
526    :
public xegpu::impl::XeGPUPeepHoleOptimizerBase<
527 XeGPUPeepHoleOptimizerPass> {
528 void runOnOperation()
override {
// Target gate: look for at least one gpu.func compiled for pvc/bmg.
536 bool isTargetSupported =
false;
537 getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
539 if (chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg"))
540 isTargetSupported =
true;
543 if (!isTargetSupported) {
544 DBGS() <<
"XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
// Ops stay legal exactly when the transpose optimization does not apply
// to them.
551 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
552 [&](xegpu::CreateNdDescOp createNdOp) {
553 return !canBeOptimizedForTranspose(createNdOp.getType());
555 target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
556 [&](xegpu::LoadNdOp loadNdOp) {
557 return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
562 target.addDynamicallyLegalOp<vector::ExtractOp>(
563 [&](vector::ExtractOp extractOp) {
565 dyn_cast<OpResult>(extractOp.getResult()));
568 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
569 auto laneData = layout.getEffectiveLaneDataAsInt();
570 return !canBeOptimizedForTranspose(laneLayout, laneData);
573 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
574 [=](Operation *op) ->
bool {
576 if (!layout || !layout.isForSubgroup())
578 if (
auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
579 return reductionOp.getReductionDims().size() != 2;
// Identity type conversion: only op rewrites, no type changes.
583 converter.addConversion([](Type type) {
return type; });
585 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
586 vector::VectorDialect>();
590 if (
failed(applyPartialConversion(getOperation(),
target,
591 std::move(patterns)))) {
592 DBGS() <<
"Optimize block loads pass failed.\n";
593 return signalPassFailure();
598 RewritePatternSet emptyPatterns(ctx);
// Cleanup: drop all discardable DistributeLayout attributes now that
// the conversion is complete.
602 getOperation()->walk([](Operation *op) {
603 SmallVector<StringAttr> attrsToRemove;
605 if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
606 attrsToRemove.push_back(namedAttr.getName());
608 for (
auto attrName : attrsToRemove)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
auto getDiscardableAttrs()
Return a range of all of discardable attributes on this operation.
Attribute removeDiscardableAttr(StringAttr name)
Remove the discardable attribute with the specified name if it exists.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
void populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns)
Appends patterns for optimizing block load operations into patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
const Instruction * getInstruction(InstructionKind instKind) const