28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallVector.h"
34#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
35#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
39#define DEBUG_TYPE "xegpu-optimize-peephole"
40#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
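// Returns the 2D lane_data of the tensor descriptor's layout, or std::nullopt
// if the descriptor has no layout or its lane data is not 2D.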
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return std::nullopt;
  return laneData;
}
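// Returns the 2D lane_layout of the tensor descriptor's layout, or
// std::nullopt if the descriptor has no layout or its lane layout is not 2D.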
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return std::nullopt;
  return laneLayout;
}
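// A load is a candidate for the transpose optimization when the lane layout
// is transposed ([subgroupSize, 1]) and the lane data packs more than one
// element along the innermost dimension ([1, packingFactor]). Illustrative
// example: f16 data with lane_layout = [16, 1] and lane_data = [1, 2]
// qualifies, whereas lane_data = [1, 1] does not.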
static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
                                       ArrayRef<int64_t> laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}
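// Overload that checks a tensor descriptor type directly: only sub-32-bit
// element types with a transpose-friendly lane layout and lane data qualify.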
static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
  // Elements of 32 bits or wider need no repacking.
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  if (elementTyBitwidth >= 32)
    return false;
  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
  auto maybeLaneData = getMaybeLaneData(tdescType);
  if (!maybeLaneData || !maybeLaneLayout)
    return false;
  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
}
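// Computes a transpose-friendly replacement descriptor type: the packed lane
// data is folded into a wider integer element type and the shape is clamped
// to a 2D block size that the target's Subgroup2DBlockLoad instruction
// supports. Returns the original type unchanged when no rewrite applies.
// Illustrative example (assuming a 16-wide subgroup): a 16x16xf16 descriptor
// with lane_data = [1, 2] may become a 16x8xi32 descriptor with
// lane_data = [1, 1].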
static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
                                         const uArch *targetuArch) {
  if (!canBeOptimizedForTranspose(tdescType))
    return tdescType;
  auto laneData = getMaybeLaneData(tdescType).value();
  int64_t innerLaneData = laneData[1];
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  // Fold the packed lane data into the innermost dimension and widen the
  // element type accordingly (e.g. two f16 elements become one i32).
  SmallVector<int64_t> requiredShape(tdescType.getShape());
  requiredShape.back() =
      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
  int newBitWidth = elementTyBitwidth * innerLaneData;
  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
  // Query the target micro-architecture for the transposed 2D block load
  // shapes supported for the widened element type.
  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
  auto maybeHWParams =
      blockLoadTarget->getBlockWidthHeightCount(newElemTy, false, true);
  if (!maybeHWParams)
    return tdescType;
  auto [widths, heights, counts] = maybeHWParams.value();
  // Transposed block loads only support an array length of 1.
  if (counts.size() != 1 || counts[0] != 1)
    return tdescType;
  int arrayLen = counts[0];
  int supportedHeight = xegpu::getLargestDivisor(requiredShape[0], heights);
  int supportedWidth = xegpu::getLargestDivisor(requiredShape[1], widths);
  if (supportedHeight == -1 || supportedWidth == -1)
    return tdescType;
  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
  // Lane data becomes [1, 1] because the elements are now fully packed.
  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
      tdescType.getContext(),
      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                    tdescType.getBoundaryCheck(),
                                    tdescType.getMemorySpace(), newLayout);
}
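// Materializes an OpFoldResult as a Value: constant integers become
// arith.constant index ops, existing Values are returned as-is.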
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
  if (mayBeInt)
    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
  return llvm::cast<Value>(ofr);
}
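// Emits an unsigned division of `val` by a compile-time constant, using a
// right shift when the constant is a power of two. For example, dividing an
// offset by a packing factor of 2 lowers to `arith.shrui %val, %c1`.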
static Value divideByConstant(ConversionPatternRewriter &rewriter,
                              Location loc, Value val, int64_t constant) {
  if (llvm::isPowerOf2_64(constant)) {
    int64_t shiftAmount = llvm::Log2_64(constant);
    Value shift = arith::ConstantIndexOp::create(rewriter, loc, shiftAmount);
    return arith::ShRUIOp::create(rewriter, loc, val, shift).getResult();
  }
  Value constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
}
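// Emits a grid of smaller transposed block loads covering the original load
// shape, inserting each loaded tile into `data` with
// vector.insert_strided_slice and returning the accumulated vector.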
static Value generateLoads(ConversionPatternRewriter &rewriter,
                           TypedValue<VectorType> data,
                           SmallVector<OpFoldResult> offsets,
                           ArrayRef<int64_t> shapeRatio,
                           TypedValue<xegpu::TensorDescType> newTensorDesc,
                           xegpu::LoadNdOp origLoadOp) {
  Location loc = origLoadOp.getLoc();
  assert(offsets.size() >= 2 &&
         "Expecting at least 2 offsets for 2D LoadNdOp");
  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
  ArrayRef<int64_t> supportedShape = newTensorDesc.getType().getShape();
  // Tile the original shape with the hardware-supported block shape and emit
  // one load per tile.
  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
      int64_t localOffsetDim0 = h * supportedShape[0];
      int64_t localOffsetDim1 = w * supportedShape[1];
      Value loadOffsetX = arith::AddIOp::create(
          rewriter, loc, offsetDim0,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0));
      Value loadOffsetY = arith::AddIOp::create(
          rewriter, loc, offsetDim1,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1));
      auto loadOp = xegpu::LoadNdOp::create(
          rewriter, loc,
          VectorType::get(supportedShape, data.getType().getElementType()),
          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
          origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
      loadOp.setAnchorLayout(layoutAttr);
      // Insert the loaded tile at its position within the accumulated vector.
      auto insertOp = vector::InsertStridedSliceOp::create(
          rewriter, loc, loadOp.getResult(), data,
          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
          ArrayRef<int64_t>{1, 1});
      data = insertOp.getResult();
    }
  }
  return data;
}
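// Rewrites xegpu.create_nd_tdesc ops whose descriptor type can be optimized
// for transposed loads: the descriptor type is converted via tryOptimize and
// the source shape and strides are rescaled to the widened element type.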
class XeGPUCreateNdDescOpPattern final
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tdescTy = createNdOp.getType();
    auto chipStr = xegpu::getChipStr(createNdOp);
    assert(chipStr &&
           (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
           "Expecting target chip to be pvc or bmg for transpose optimization.");
    const uArch *targetuArch = getUArch(chipStr.value());
    auto convertType = tryOptimize(tdescTy, targetuArch);
    if (convertType == tdescTy)
      return failure();
    // The optimization assumes a row-major source, i.e. a unit inner stride.
    auto strides = createNdOp.getMixedStrides();
    auto maybeConstInnerStride = getConstantIntValue(strides.back());
    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
      return rewriter.notifyMatchFailure(
          createNdOp, "Expecting row-major memref for transpose optimization.");
    Value source = createNdOp.getSource();
    auto optionalLaneData = getMaybeLaneData(tdescTy);
    assert(optionalLaneData && "Expected 2D lane data");
    auto laneData = optionalLaneData.value();
    int64_t innerLaneData = laneData[1];
    auto memrefType = dyn_cast<MemRefType>(source.getType());
    // Rescale the innermost size and the second-to-innermost stride to the
    // widened element type.
    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
    modifiedShape.back() = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
        innerLaneData);
    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(),
                       modifiedStrides[modifiedStrides.size() - 2]),
        innerLaneData);
    // Statically shaped memrefs are rebuilt from their base pointer so the
    // new shape and strides can be applied.
    if (memrefType && memrefType.hasStaticShape()) {
      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
          rewriter, createNdOp.getLoc(), source);
      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
                                          rewriter.getI64Type(),
                                          extractOp.getResult())
                   .getResult();
    }
    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
        modifiedStrides);
    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
    return success();
  }
};
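// Rewrites xegpu.load_nd ops on converted descriptors: offsets are rescaled,
// the load is split into hardware-supported tiles via generateLoads, and the
// result is bitcast back to the original element type (one vector per array
// slice when array_length > 1).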
class XeGPULoadNdDescOpPattern final
    : public OpConversionPattern<xegpu::LoadNdOp> {
public:
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origTensorDescType = loadNdOp.getTensorDescType();
    auto adaptorType =
        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
    if (adaptorType == origTensorDescType)
      return failure();
    // The packing factor comes from the original descriptor's lane data.
    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
    int64_t innerLaneData = laneData[1];
    auto offsets = loadNdOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadNdOp,
                                         "Expecting offsets in LoadNd");
    // Rescale the innermost offset to the widened element type.
    SmallVector<OpFoldResult> modifiedOffsets(offsets);
    modifiedOffsets.back() = divideByConstant(
        rewriter, loadNdOp.getLoc(),
        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
        innerLaneData);
    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
    origDataShape.back() /= innerLaneData;
    auto shapeRatio =
        computeShapeRatio(origDataShape, adaptorType.getShape()).value();
    VectorType origVectorType =
        VectorType::get(origDataShape, adaptorType.getElementType());
    // Descriptors with array_length > 1 produce one vector per slice.
    if (origTensorDescType.getArrayLength() > 1) {
      SmallVector<Value> arraySlices;
      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
        Value slice = arith::ConstantOp::create(
            rewriter, loadNdOp->getLoc(), origVectorType,
            rewriter.getZeroAttr(origVectorType));
        // Advance the innermost offset for this array slice.
        Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                       modifiedOffsets.back());
        modifiedOffsets.back() =
            arith::AddIOp::create(
                rewriter, loadNdOp->getLoc(), offsetY,
                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
                                               i * origDataShape[1]))
                .getResult();
        slice = generateLoads(
            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
            shapeRatio,
            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
            loadNdOp);
        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                           origTensorDescType.getElementType());
        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                                   bitcastType, slice);
        xegpu::setDistributeLayoutAttr(bitCastOp->getResult(0),
                                       origTensorDescType.getLayoutAttr());
        arraySlices.push_back(bitCastOp.getResult());
      }
      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
      return success();
    }
    Value data = arith::ConstantOp::create(
        rewriter, loadNdOp->getLoc(),
        VectorType::get(origDataShape, adaptorType.getElementType()),
        rewriter.getZeroAttr(origVectorType));
    data = generateLoads(
        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
        shapeRatio,
        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
        loadNdOp);
    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                               loadNdOp.getType(), data);
    xegpu::setDistributeLayoutAttr(bitCastOp->getResult(0),
                                   origTensorDescType.getLayoutAttr());
    rewriter.replaceOp(loadNdOp, bitCastOp);
    return success();
  }
};
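// When a load with array_length > 1 has been replaced by one vector per
// slice, a vector.extract with a constant slice index simply forwards the
// matching replacement value.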
class VectorExtractOpPattern final
    : public OpConversionPattern<vector::ExtractOp> {
public:
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getSource().size() == 1)
      return failure();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() != 1)
      return failure();
    auto mayBeInt = getConstantIntValue(mixedPos[0]);
    if (!mayBeInt)
      return failure();
    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
    return success();
  }
};
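// Splits a 2D vector.multi_reduction over both dimensions into an intra-lane
// vector.multi_reduction followed by a cross-lane vector.reduction, and
// propagates the distribute layouts onto each intermediate result.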
class MultiRed2dOpPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sourceVecType = reductionOp.getSourceVectorType();
    if (reductionOp.getReductionDims().size() != 2 ||
        sourceVecType.getRank() != 2)
      return rewriter.notifyMatchFailure(
          reductionOp, "Expected 2D multi reduction of a 2D source");
    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
    auto dims = llvm::to_vector(reductionOp.getReductionDims());
    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
    // If the layout does not determine an order, fall back to source order.
    if (intraLaneDim == -1 || crossLaneDim == -1) {
      intraLaneDim = dims[0];
      crossLaneDim = dims[1];
    }
    auto loc = reductionOp.getLoc();
    auto acc = reductionOp.getAcc();
    // Layout of the intermediate (intra-lane) reduction result.
    auto resSliceLayoutAttr = cast<xegpu::SliceAttr>(resLayout);
    SmallVector<int64_t> dropDims{crossLaneDim};
    auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims);
    // Broadcast the accumulator to the shape of the intermediate result.
    SmallVector<int64_t> accShape(sourceVecType.getShape());
    accShape.erase(accShape.begin() + intraLaneDim);
    acc = vector::BroadcastOp::create(
        rewriter, loc,
        VectorType::get(accShape, sourceVecType.getElementType()), acc);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(acc),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    // First reduce within each lane.
    Value intraLaneReduced = vector::MultiDimReductionOp::create(
        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
        ArrayRef<int64_t>(intraLaneDim));
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(intraLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    // Then reduce across lanes.
    Value crossLaneReduced = vector::ReductionOp::create(
        rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(crossLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(resLayout));
    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
           "Expected the cross-lane reduction to produce the original type");
    rewriter.replaceOp(reductionOp, crossLaneReduced);
    return success();
  }
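  // Given the two reduction dims and the result layout, returns which dim is
  // reduced within a lane (lane_layout == 1) and which across lanes; a dim
  // the layout cannot classify is reported as -1.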
  std::pair<int64_t, int64_t>
  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
                       xegpu::DistributeLayoutAttr layout) const {
    assert(layout.isForSubgroup() && "Must know the lane layout");
    assert(reductionDims.size() == 2 && "Expected 2D reduction");
    int64_t intra = -1, cross = -1;
    xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
    if (auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
      layoutAttr =
          dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
    SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
    assert(laneLayout.size() && "Expected a non-empty layout");
    for (auto dim : reductionDims) {
      if (laneLayout[dim] == 1)
        intra = dim;
      else
        cross = dim;
    }
    return {intra, cross};
  }
};
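// Appends the transpose/block-load optimization patterns defined above.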
void xegpu::populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
               VectorExtractOpPattern, MultiRed2dOpPattern>(
      patterns.getContext());
}
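// Pass driver: bails out unless the module targets a supported chip, then
// runs a partial conversion with the patterns above.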
struct XeGPUPeepHoleOptimizerPass final
    : public xegpu::impl::XeGPUPeepHoleOptimizerBase<
          XeGPUPeepHoleOptimizerPass> {
  void runOnOperation() override {
    // The transposed block load rewrite is only known to be supported on PVC
    // and BMG targets.
    bool isTargetSupported = false;
    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
      auto chipStr = xegpu::getChipStr(funcOp);
      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
        isTargetSupported = true;
    });
    if (!isTargetSupported) {
      DBGS() << "XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
             << "\n";
      return;
    }
    ConversionTarget target(getContext());
    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
        [&](xegpu::CreateNdDescOp createNdOp) {
          return !canBeOptimizedForTranspose(createNdOp.getType());
        });
    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
        [&](xegpu::LoadNdOp loadNdOp) {
          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
        });
    // vector.extract is only illegal when its source carries a
    // transpose-optimizable layout.
    target.addDynamicallyLegalOp<vector::ExtractOp>(
        [&](vector::ExtractOp extractOp) {
          auto layout = xegpu::getTemporaryLayout(
              dyn_cast<OpResult>(extractOp.getResult()));
          if (!layout)
            return true;
          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
          auto laneData = layout.getEffectiveLaneDataAsInt();
          return !canBeOptimizedForTranspose(laneLayout, laneData);
        });
    // 2D multi-reductions with a subgroup-distributed layout are split by
    // MultiRed2dOpPattern.
    target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
        [=](Operation *op) -> bool {
          auto layout = xegpu::getTemporaryLayout(op->getResult(0));
          if (!layout || !layout.isForSubgroup())
            return true;
          if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
            return reductionOp.getReductionDims().size() != 2;
          return true;
        });
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           vector::VectorDialect>();
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUPeepHoleOptimizerPatterns(patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      DBGS() << "Optimize block loads pass failed.\n";
      return signalPassFailure();
    }
  }
};