#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;
#define DEBUG_TYPE "xegpu-optimize-peephole"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
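// Overview (illustrative summary, not from the original comments): the
// patterns in this file rewrite transposed block loads of sub-32-bit data
// into block loads of wider integer elements. The tensor descriptor, its
// sizes/strides/offsets, and the load itself are rewritten in terms of the
// wider element type, and the loaded value is bitcast back to the original
// element type so the rest of the program is unchanged.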
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return std::nullopt;
  return laneData;
}
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return std::nullopt;
  return laneLayout;
}
static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
                                       ArrayRef<int64_t> laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  // Lanes must be distributed along the outer dimension only.
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  // Elements must be packed along the inner dimension only.
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}
static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
  // Only sub-32-bit element types benefit from widening.
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  if (elementTyBitwidth >= 32)
    return false;
  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
  auto maybeLaneData = getMaybeLaneData(tdescType);
  if (!maybeLaneData || !maybeLaneLayout)
    return false;
  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
}
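// Illustrative qualification example (values assumed): a descriptor of f16
// elements with lane_layout = [16, 1] and lane_data = [1, 2] passes the
// checks above, whereas a 32-bit element type or lane_data = [1, 1] (nothing
// packed along the inner dimension) is left untouched.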
static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
                                         const uArch *targetuArch) {
  if (!canBeOptimizedForTranspose(tdescType))
    return tdescType;
  auto laneData = getMaybeLaneData(tdescType).value();
  int64_t innerLaneData = laneData[1];
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  // Fold the array length and the packing factor into the innermost
  // dimension of the required shape.
  SmallVector<int64_t> requiredShape = llvm::to_vector(tdescType.getShape());
  requiredShape.back() =
      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
  int newBitWidth = elementTyBitwidth * innerLaneData;
  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
  // Query the target uArch for the supported 2D block load parameters.
  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
  auto maybeHWParams =
      blockLoadTarget->getBlockWidthHeightCount(newElemTy, false, true);
  if (!maybeHWParams)
    return tdescType;
  auto [widths, heights, counts] = maybeHWParams.value();
  // Only a single block count is supported here.
  if (counts.size() != 1 || counts[0] != 1)
    return tdescType;
  int arrayLen = counts[0];
  int supportedHeight =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
  int supportedWidth =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape.back()), widths);
  if (supportedHeight == -1 || supportedWidth == -1)
    return tdescType;
  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
  // Keep the lane layout, but reset lane data to [1, 1] for the new type.
  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
      tdescType.getContext(),
      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                    tdescType.getBoundaryCheck(),
                                    tdescType.getMemorySpace(), newLayout);
}
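// Worked example for tryOptimize (shapes assumed for illustration): for an
// f16 descriptor with lane_data = [1, 2], innerLaneData is 2, so the new
// element type is i32 (16 * 2 bits) and a required 16x16 f16 tile becomes a
// 16x8 i32 tile, further clamped to the block widths/heights reported by the
// target's Subgroup2DBlockLoad instruction description.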
// Convert an OpFoldResult into a Value, materializing constants as
// arith.constant index ops.
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  if (std::optional<int64_t> mayBeInt = getConstantIntValue(ofr))
    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt)
        .getResult();
  return llvm::cast<Value>(ofr);
}
static Value divideByConstant(ConversionPatternRewriter &rewriter,
                              Location loc, Value val, int64_t constant) {
  // Use a right shift when the divisor is a power of 2.
  if (llvm::isPowerOf2_64(constant)) {
    int64_t shiftAmount = llvm::Log2_64(constant);
    return arith::ShRUIOp::create(
               rewriter, loc, val,
               arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
                   .getResult())
        .getResult();
  }
  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
}
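// For example, dividing an offset by 2 (a common innerLaneData value) is
// emitted as "arith.shrui %val, %c1" instead of an arith.divui; both treat
// the operand as unsigned, so the shift is an exact replacement for
// power-of-two divisors.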
static Value generateLoads(ConversionPatternRewriter &rewriter,
                           TypedValue<VectorType> data,
                           SmallVector<OpFoldResult> offsets,
                           TypedValue<xegpu::TensorDescType> newTensorDesc,
                           xegpu::LoadNdOp origLoadOp) {
  Location loc = origLoadOp.getLoc();
  assert(offsets.size() >= 2 &&
         "Expecting at least 2 offsets for 2D LoadNdOp");
  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
  // Tile the full data shape into HW-supported tiles and load tile by tile.
  ArrayRef<int64_t> supportedShape = newTensorDesc.getType().getShape();
  SmallVector<int64_t> shapeRatio =
      computeShapeRatio(data.getType().getShape(), supportedShape).value();
  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
      int64_t localOffsetDim0 = h * supportedShape[0];
      int64_t localOffsetDim1 = w * supportedShape[1];
      Value loadOffsetX = arith::AddIOp::create(
          rewriter, loc, offsetDim0,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
              .getResult());
      Value loadOffsetY = arith::AddIOp::create(
          rewriter, loc, offsetDim1,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
              .getResult());
      auto loadOp = xegpu::LoadNdOp::create(
          rewriter, loc,
          VectorType::get(supportedShape, data.getType().getElementType()),
          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
          origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
      // New tile load inherits the layout of the adjusted tensor descriptor.
      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
      loadOp.setAnchorLayout(layoutAttr);
      // Insert the tile into the accumulated result vector.
      auto insertOp = vector::InsertStridedSliceOp::create(
          rewriter, loc, loadOp.getResult(), data,
          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
          ArrayRef<int64_t>{1, 1});
      data = insertOp.getResult();
    }
  }
  return data;
}
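// Summary of the helper above (descriptive note): generateLoads tiles the
// requested data into hardware-supported blocks, issues one xegpu.load_nd per
// tile at the corresponding offset, and assembles the tiles back into a
// single vector with vector.insert_strided_slice.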
class XeGPUCreateNdDescOpPattern final
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tdescTy = createNdOp.getType();
    auto chipStr = xegpu::getChipStr(createNdOp);
    assert(chipStr &&
           (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
           "Expecting target chip to be pvc or bmg for transpose optimization.");
    const uArch *targetuArch = getUArch(chipStr.value());
    auto convertType = tryOptimize(tdescTy, targetuArch);
    if (convertType == tdescTy)
      return failure();
    auto strides = createNdOp.getMixedStrides();
    auto maybeConstInnerStride = getConstantIntValue(strides.back());
    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
      return rewriter.notifyMatchFailure(
          createNdOp, "Expecting row-major memref for transpose optimization.");
    Value source = createNdOp.getSource();
    auto optionalLaneData = getMaybeLaneData(tdescTy);
    assert(optionalLaneData && "Expected 2D lane data");
    auto laneData = optionalLaneData.value();
    int64_t innerLaneData = laneData[1];
    auto memrefType = dyn_cast<MemRefType>(source.getType());
    // Divide the innermost size by the packing factor.
    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
    modifiedShape.back() = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
        innerLaneData);
    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    // Divide the second-to-innermost stride by the packing factor as well.
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(),
                       modifiedStrides[modifiedStrides.size() - 2]),
        innerLaneData);
    // For statically shaped memrefs, use the base pointer (as i64) as the
    // source, since the adjusted sizes no longer match the memref type.
    if (memrefType && memrefType.hasStaticShape()) {
      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
          rewriter, createNdOp.getLoc(), source);
      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
                                          rewriter.getI64Type(),
                                          extractOp.getResult())
                   .getResult();
    }
    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
        modifiedStrides);
    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
    return success();
  }
};
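// Illustrative effect of this pattern (sizes assumed): a row-major 64x64xf16
// source with innerLaneData = 2 is re-described as 64x32 elements of i32;
// the innermost size and the second-to-innermost stride are divided by 2,
// while the unit innermost stride is required and left as is.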
class XeGPULoadNdDescOpPattern final
    : public OpConversionPattern<xegpu::LoadNdOp> {
public:
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origTensorDescType = loadNdOp.getTensorDescType();
    auto adaptorType =
        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
    if (adaptorType == origTensorDescType)
      return failure();
    // The tensor descriptor was rewritten to the wider element type; adjust
    // offsets and result type accordingly.
    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
    int64_t innerLaneData = laneData[1];
    auto offsets = loadNdOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadNdOp,
                                         "Expecting offsets in LoadNd");
    SmallVector<OpFoldResult> modifiedOffsets(offsets);
    modifiedOffsets.back() = divideByConstant(
        rewriter, loadNdOp.getLoc(),
        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
        innerLaneData);
    // Original data shape expressed in units of the new element type.
    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
    origDataShape.back() /= innerLaneData;
    auto newTensorDesc =
        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc());
    VectorType origVectorType =
        VectorType::get(origDataShape, adaptorType.getElementType());
    Value data;
    // Descriptors with array_length > 1 are loaded slice by slice; the op is
    // then replaced with one value per slice.
    if (origTensorDescType.getArrayLength() > 1) {
      SmallVector<Value> arraySlices;
      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
        Value slice = arith::ConstantOp::create(
            rewriter, loadNdOp->getLoc(), origVectorType,
            rewriter.getZeroAttr(origVectorType));
        // Advance the innermost offset to the start of this array slice.
        Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                       modifiedOffsets.back());
        modifiedOffsets.back() =
            arith::AddIOp::create(
                rewriter, loadNdOp->getLoc(), offsetY,
                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
                                               i * origDataShape[1])
                    .getResult())
                .getResult();
        slice = generateLoads(rewriter, cast<TypedValue<VectorType>>(slice),
                              modifiedOffsets, newTensorDesc, loadNdOp);
        // Bitcast the widened slice back to the original element type.
        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                           origTensorDescType.getElementType());
        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                                   bitcastType, slice);
        xegpu::setDistributeLayoutAttr(
            dyn_cast<OpResult>(bitCastOp.getResult()),
            origTensorDescType.getLayoutAttr());
        arraySlices.push_back(bitCastOp.getResult());
      }
      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
      return success();
    }
    data = arith::ConstantOp::create(
        rewriter, loadNdOp->getLoc(),
        VectorType::get(origDataShape, adaptorType.getElementType()),
        rewriter.getZeroAttr(origVectorType));
    data = generateLoads(rewriter, cast<TypedValue<VectorType>>(data),
                         modifiedOffsets, newTensorDesc, loadNdOp);
    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                               loadNdOp.getType(), data);
    xegpu::setDistributeLayoutAttr(dyn_cast<OpResult>(bitCastOp.getResult()),
                                   origTensorDescType.getLayoutAttr());
    rewriter.replaceOp(loadNdOp, bitCastOp);
    return success();
  }
};
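// Note on the array_length > 1 case above: the load is replaced with one
// value per array slice via replaceOpWithMultiple, and the pattern below
// rewrites vector.extract ops that previously selected an array slice so
// that they forward the matching replacement value directly.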
class VectorExtractOpPattern final
    : public OpConversionPattern<vector::ExtractOp> {
public:
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only extracts whose source was replaced by multiple values (the
    // array_length > 1 case above) need rewriting.
    if (adaptor.getSource().size() == 1)
      return failure();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() != 1)
      return failure();
    auto mayBeInt = getConstantIntValue(mixedPos[0]);
    if (!mayBeInt)
      return failure();
    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
    return success();
  }
};
class MultiRed2dOpPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sourceVecType = reductionOp.getSourceVectorType();
    if (reductionOp.getReductionDims().size() != 2 ||
        sourceVecType.getRank() != 2)
      return rewriter.notifyMatchFailure(
          reductionOp, "Expected 2D multi reduction of a 2D source");
    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
    auto dims = llvm::to_vector(reductionOp.getReductionDims());
    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
    // Fall back to the source order if the lane layout does not decide it.
    if (intraLaneDim == -1 || crossLaneDim == -1) {
      intraLaneDim = dims[0];
      crossLaneDim = dims[1];
    }
    auto loc = reductionOp.getLoc();
    auto acc = reductionOp.getAcc();
    // Layout of the intermediate (intra-lane) reduction result.
    auto resSliceLayoutAttr = cast<xegpu::SliceAttr>(resLayout);
    SmallVector<int64_t> dropDims = {crossLaneDim};
    auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims);
    // Broadcast the scalar accumulator to the shape left after reducing the
    // intra-lane dimension.
    SmallVector<int64_t> accShape(sourceVecType.getShape());
    accShape.erase(accShape.begin() + intraLaneDim);
    acc = vector::BroadcastOp::create(
        rewriter, loc,
        VectorType::get(accShape, sourceVecType.getElementType()), acc);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(acc),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    // Step 1: reduce the dimension owned by each lane.
    Value intraLaneReduced = vector::MultiDimReductionOp::create(
        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
        ArrayRef<int64_t>{intraLaneDim});
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(intraLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    // Step 2: reduce the remaining dimension across lanes.
    Value crossLaneReduced = vector::ReductionOp::create(
        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(crossLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(resLayout));
    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
           "Expected the final reduction type to match the original result");
    rewriter.replaceOp(reductionOp, crossLaneReduced);
    return success();
  }

private:
  // Decide which reduction dimension is reduced within a lane and which one
  // across lanes, based on the lane layout.
  std::pair<int64_t, int64_t>
  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
                       xegpu::DistributeLayoutAttr layout) const {
    assert(layout.isForSubgroup() && "Must know the lane layout");
    assert(reductionDims.size() == 2 && "Expected 2D reduction");
    int64_t intra = -1, cross = -1;
    xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
    if (auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
      layoutAttr =
          dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
    if (!layoutAttr)
      return {intra, cross};
    auto laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
    assert(laneLayout.size() && "Expected a non-empty layout");
    for (auto dim : reductionDims) {
      if (laneLayout[dim] == 1)
        intra = dim;
      else
        cross = dim;
    }
    return {intra, cross};
  }
};
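// Illustrative example for getReductionDimOrder (lane layout assumed): with
// lane_layout = [1, 16] and reduction dims [0, 1], dim 0 is fully owned by
// each lane (laneLayout[0] == 1) and becomes the intra-lane dimension reduced
// with vector.multi_reduction, while dim 1 is distributed across lanes and is
// reduced by the final vector.reduction.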
void xegpu::populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
               VectorExtractOpPattern, MultiRed2dOpPattern>(
      patterns.getContext());
}
struct XeGPUPeepHoleOptimizerPass final
    : public xegpu::impl::XeGPUPeepHoleOptimizerBase<
          XeGPUPeepHoleOptimizerPass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    // The rewrite relies on uArch block load parameters, so bail out early on
    // unsupported targets.
    bool isTargetSupported = false;
    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
      auto chipStr = xegpu::getChipStr(funcOp);
      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
        isTargetSupported = true;
    });
    if (!isTargetSupported) {
      DBGS() << "XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
             << "\n";
      return;
    }
    RewritePatternSet patterns(&context);
    ConversionTarget target(context);
    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
        [&](xegpu::CreateNdDescOp createNdOp) {
          return !canBeOptimizedForTranspose(createNdOp.getType());
        });
    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
        [&](xegpu::LoadNdOp loadNdOp) {
          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
        });
    target.addDynamicallyLegalOp<vector::ExtractOp>(
        [&](vector::ExtractOp extractOp) {
          auto layout = xegpu::getTemporaryLayout(
              dyn_cast<OpResult>(extractOp.getResult()));
          if (!layout)
            return true;
          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
          auto laneData = layout.getEffectiveLaneDataAsInt();
          return !canBeOptimizedForTranspose(laneLayout, laneData);
        });
    target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
        [=](Operation *op) -> bool {
          auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
          if (!layout || !layout.isForSubgroup())
            return true;
          if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
            return reductionOp.getReductionDims().size() != 2;
          return true;
        });
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           vector::VectorDialect>();
    xegpu::populateXeGPUPeepHoleOptimizerPatterns(patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      DBGS() << "Optimize block loads pass failed.\n";
      return signalPassFailure();
    }
  }
};