28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
#define DEBUG_TYPE "xegpu-optimize-peephole"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
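/// Returns the 2D effective lane data of the tensor descriptor's layout, or
/// std::nullopt if the descriptor has no layout or the lane data is not 2D.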
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return std::nullopt;
  return laneData;
}
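/// Returns the 2D effective lane layout of the tensor descriptor's layout, or
/// std::nullopt if the descriptor has no layout or the lane layout is not 2D.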
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return std::nullopt;
  return laneLayout;
}
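/// A layout qualifies for the transpose optimization if the lane layout is
/// [N, 1] with N > 1 (each lane reads down a column) and the lane data is
/// [1, M] with M > 1 (each lane owns M consecutive elements along the
/// innermost dimension), i.e. a transposed, packable access.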
static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
                                       ArrayRef<int64_t> laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}
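/// Descriptor-level wrapper: the optimization only applies to sub-32-bit
/// element types whose layout satisfies the lane-layout/lane-data conditions
/// above.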
static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  if (elementTyBitwidth >= 32)
    return false;
  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
  auto maybeLaneData = getMaybeLaneData(tdescType);
  if (!maybeLaneData || !maybeLaneLayout)
    return false;
  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
}
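/// Computes a hardware-friendly replacement descriptor type for a transposed
/// low-precision load: `innerLaneData` consecutive elements are packed into a
/// single wider integer element, the innermost dimension shrinks by the same
/// factor, and the shape is clamped to block sizes the target's 2D block load
/// instruction supports. Returns the original type if no rewrite applies.
///
/// Illustrative example (assuming the target supports a 16x8 i32 block):
///   !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1],
///                                               lane_data = [1, 2]>>
/// becomes
///   !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1],
///                                              lane_data = [1, 1]>>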
static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
                                         const uArch *targetuArch) {
  if (!canBeOptimizedForTranspose(tdescType))
    return tdescType;
  auto laneData = getMaybeLaneData(tdescType)
                      .value(); // Lane data is known to be valid here.
  int64_t innerLaneData = laneData[1];
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  // Pack innerLaneData consecutive elements into one wider integer element.
  SmallVector<int64_t> requiredShape(tdescType.getShape());
  requiredShape.back() =
      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
  int newBitWidth = elementTyBitwidth * innerLaneData;
  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
  // Consult the target uArch for supported 2D block load sizes for the new
  // element type.
  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
  auto maybeHWParams =
      blockLoadTarget->getBlockWidthHeightCount(newElemTy, false, true);
  if (!maybeHWParams)
    return tdescType;
  auto [widths, heights, counts] = maybeHWParams.value();
  // Array (multi-block) loads are not combined further at this point.
  if (counts.size() != 1 || counts[0] != 1)
    return tdescType;
  int arrayLen = counts[0];
  int supportedHeight =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
  int supportedWidth =
      xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
  if (supportedHeight == -1 || supportedWidth == -1)
    return tdescType;
  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
  // Lane layout is unchanged; lane data becomes [1, 1] for the packed type.
  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
      tdescType.getContext(), tdescType.getLayoutAttr().getLaneLayout(),
      DenseI32ArrayAttr::get(tdescType.getContext(), {1, 1}),
      tdescType.getLayoutAttr().getOrder());
  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                    tdescType.getBoundaryCheck(),
                                    tdescType.getMemorySpace(), newLayout);
}
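/// Materializes an OpFoldResult as a Value, creating an arith.constant for
/// the attribute case.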
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
  if (mayBeInt)
    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
  return llvm::cast<Value>(ofr);
}
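/// Divides `val` by a constant, lowering to an unsigned right shift when the
/// divisor is a power of two (e.g. x / 8 becomes x >> 3).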
static Value divideByConstant(ConversionPatternRewriter &rewriter,
                              Location loc, Value val, int64_t constant) {
  if (llvm::isPowerOf2_64(constant)) {
    int64_t shiftAmount = llvm::Log2_64(constant);
    return arith::ShRUIOp::create(
               rewriter, loc, val,
               arith::ConstantIndexOp::create(rewriter, loc, shiftAmount))
        .getResult();
  }
  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
}
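/// Emits a grid of hardware-sized LoadNdOps that together cover the shape of
/// `data`, inserting each partial load into `data` at its local offset and
/// returning the accumulated vector.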
static Value generateLoads(ConversionPatternRewriter &rewriter,
                           TypedValue<VectorType> data,
                           SmallVector<OpFoldResult> offsets,
                           TypedValue<xegpu::TensorDescType> newTensorDesc,
                           xegpu::LoadNdOp origLoadOp) {
  Location loc = origLoadOp.getLoc();
  assert(offsets.size() >= 2 &&
         "Expecting at least 2 offsets for 2D LoadNdOp");
  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
  SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
  // Number of loads needed along each dimension to cover the full shape.
  auto shapeRatio =
      computeShapeRatio(data.getType().getShape(), supportedShape).value();
  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
      int64_t localOffsetDim0 = h * supportedShape[0];
      int64_t localOffsetDim1 = w * supportedShape[1];
      Value loadOffsetX = arith::AddIOp::create(
          rewriter, loc, offsetDim0,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0));
      Value loadOffsetY = arith::AddIOp::create(
          rewriter, loc, offsetDim1,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1));
      auto loadOp = xegpu::LoadNdOp::create(
          rewriter, loc,
          VectorType::get(supportedShape, data.getType().getElementType()),
          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
          origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
      loadOp.setAnchorLayout(layoutAttr);
      // Insert the partial result at its position in the accumulated vector.
      auto insertOp = vector::InsertStridedSliceOp::create(
          rewriter, loc, loadOp.getResult(), data,
          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
          ArrayRef<int64_t>{1, 1});
      data = insertOp.getResult();
    }
  }
  return data;
}

namespace {
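/// Rewrites CreateNdDescOp to produce the optimized (packed) descriptor type
/// computed by tryOptimize, rescaling the source shape and strides for the
/// wider element type.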
class XeGPUCreateNdDescOpPattern final
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tdescTy = createNdOp.getType();
    auto chipStr = xegpu::getChipStr(createNdOp);
    assert(chipStr &&
           (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
           "Expecting target chip to be pvc or bmg for transpose optimization.");
    const uArch *targetuArch = getUArch(chipStr.value());
    auto convertType = tryOptimize(tdescTy, targetuArch);
    if (convertType == tdescTy)
      return failure();
    auto strides = createNdOp.getMixedStrides();
    // The transpose optimization requires a row-major source.
    auto maybeConstInnerStride = getConstantIntValue(strides.back());
    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
      return rewriter.notifyMatchFailure(
          createNdOp, "Expecting row-major memref for transpose optimization.");
    Value source = createNdOp.getSource();
    auto optionalLaneData = getMaybeLaneData(tdescTy);
    assert(optionalLaneData && "Expected 2D lane data");
    auto laneData = optionalLaneData.value();
    int64_t innerLaneData = laneData[1];
    auto memrefType = dyn_cast<MemRefType>(source.getType());
    // Rescale the innermost size and the second-to-innermost stride by the
    // packing factor.
    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
    modifiedShape.back() = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
        innerLaneData);
    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(),
                       modifiedStrides[modifiedStrides.size() - 2]),
        innerLaneData);
    // For a statically shaped memref, build the new descriptor from the raw
    // base pointer, since the memref's element type no longer matches the
    // optimized descriptor.
    if (memrefType && memrefType.hasStaticShape()) {
      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
          rewriter, createNdOp.getLoc(), source);
      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
                                          rewriter.getI64Type(),
                                          extractOp.getResult())
                   .getResult();
    }
    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
        modifiedStrides);
    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
    return success();
  }
};
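/// Rewrites LoadNdOp on an optimized descriptor: offsets are rescaled for the
/// packed element type, one or more hardware-sized loads are emitted via
/// generateLoads, and the result is bitcast back to the original element
/// type. Array loads (array_length > 1) are unrolled into one value per
/// slice.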
class XeGPULoadNdDescOpPattern final
    : public OpConversionPattern<xegpu::LoadNdOp> {
public:
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origTensorDescType = loadNdOp.getTensorDescType();
    auto adaptorType =
        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
    if (adaptorType == origTensorDescType)
      return failure();
    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
    int64_t innerLaneData = laneData[1];
    auto offsets = loadNdOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadNdOp,
                                         "Expecting offsets in LoadNd");
    SmallVector<OpFoldResult> modifiedOffsets(offsets);
    // Rescale the innermost offset by the packing factor.
    modifiedOffsets.back() = divideByConstant(
        rewriter, loadNdOp.getLoc(),
        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
        innerLaneData);
    // Original load shape in packed-element units.
    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
    origDataShape.back() /= innerLaneData;
    // Per-load shape supported by the hardware.
    SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
    VectorType origVectorType =
        VectorType::get(origDataShape, adaptorType.getElementType());
    Value data;
    if (origTensorDescType.getArrayLength() > 1) {
      SmallVector<Value> arraySlices;
      // Remember the base inner offset; slice i starts at base + i * width.
      Value baseOffsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                         modifiedOffsets.back());
      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
        Value slice = arith::ConstantOp::create(
            rewriter, loadNdOp->getLoc(), origVectorType,
            rewriter.getZeroAttr(origVectorType));
        modifiedOffsets.back() =
            arith::AddIOp::create(
                rewriter, loadNdOp->getLoc(), baseOffsetY,
                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
                                               i * origDataShape[1]))
                .getResult();
        slice = generateLoads(
            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
            loadNdOp);
        // Bitcast each packed slice back to the original element type.
        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                           origTensorDescType.getElementType());
        auto bitCastOp = vector::BitCastOp::create(
            rewriter, loadNdOp->getLoc(), bitcastType, slice);
        xegpu::setDistributeLayoutAttr(
            llvm::dyn_cast<OpResult>(bitCastOp.getResult()),
            origTensorDescType.getLayoutAttr());
        arraySlices.push_back(bitCastOp.getResult());
      }
      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
      return success();
    }
    data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
                                     origVectorType,
                                     rewriter.getZeroAttr(origVectorType));
    data = generateLoads(
        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
        loadNdOp);
    // Bitcast the packed result back to the original load type.
    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                               loadNdOp.getType(), data);
    xegpu::setDistributeLayoutAttr(
        llvm::dyn_cast<OpResult>(bitCastOp.getResult()),
        origTensorDescType.getLayoutAttr());
    rewriter.replaceOp(loadNdOp, bitCastOp);
    return success();
  }
};
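/// With array loads unrolled into multiple replacement values, a
/// vector.extract that selects one array slice is folded away by forwarding
/// the corresponding replacement value.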
class VectorExtractOpPattern final
    : public OpConversionPattern<vector::ExtractOp> {
public:
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Nothing to do if the source was not unrolled into multiple values.
    if (adaptor.getSource().size() == 1)
      return failure();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() != 1)
      return failure();
    auto mayBeInt = getConstantIntValue(mixedPos[0]);
    if (!mayBeInt)
      return failure();
    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
    return success();
  }
};
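/// Splits a 2D multi-reduction into an intra-lane vector.multi_reduction
/// over the dimension each lane fully owns, followed by a cross-lane
/// vector.reduction, propagating distribute layouts to the new intermediate
/// values.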
class MultiRed2dOpPattern
    : public OpConversionPattern<vector::MultiDimReductionOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sourceVecType = reductionOp.getSourceVectorType();
    if (reductionOp.getReductionDims().size() != 2 ||
        sourceVecType.getRank() != 2)
      return rewriter.notifyMatchFailure(
          reductionOp, "Expected 2D multi reduction of a 2D source");
    auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult());
    auto dims = llvm::to_vector(reductionOp.getReductionDims());
    auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
    // If the order cannot be determined from the layout, fall back to the
    // order in which the dims were specified.
    if (intraLaneDim == -1 || crossLaneDim == -1) {
      intraLaneDim = dims[0];
      crossLaneDim = dims[1];
    }
    auto loc = reductionOp.getLoc();
    auto acc = reductionOp.getAcc();
    // The intra-lane reduction result keeps the slice layout minus the
    // cross-lane dimension.
    auto resSliceLayoutAttr = cast<xegpu::SliceAttr>(resLayout);
    SmallVector<int64_t> dropDims{crossLaneDim};
    auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims);
    // Broadcast the accumulator to the shape of the intra-lane result.
    SmallVector<int64_t> accShape(sourceVecType.getShape());
    accShape.erase(accShape.begin() + intraLaneDim);
    acc = vector::BroadcastOp::create(
        rewriter, loc,
        VectorType::get(accShape, sourceVecType.getElementType()), acc);
    xegpu::setTemporaryLayout(
        llvm::dyn_cast<OpResult>(acc),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    Value intraLaneReduced = vector::MultiDimReductionOp::create(
        rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc,
        ArrayRef<int64_t>(intraLaneDim));
    xegpu::setTemporaryLayout(
        llvm::dyn_cast<OpResult>(intraLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(intraLaneRedResLayout));
    Value crossLaneReduced = vector::ReductionOp::create(
        rewriter, loc, reductionOp.getKind(), intraLaneReduced,
        /*acc=*/nullptr);
    xegpu::setTemporaryLayout(
        llvm::dyn_cast<OpResult>(crossLaneReduced),
        cast<xegpu::DistributeLayoutAttr>(resLayout));
    assert(crossLaneReduced.getType() == reductionOp.getResult().getType() &&
           "Expected cross-lane reduction to produce the original result type");
    rewriter.replaceOp(reductionOp, crossLaneReduced);
    return success();
  }

private:
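  /// Returns the pair (intra-lane dim, cross-lane dim) for the reduction: a
  /// dimension with effective lane layout 1 is fully owned by a single lane
  /// (intra-lane); the other dimension is distributed across lanes. Returns
  /// {-1, -1} if the order cannot be determined from the layout.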
  std::pair<int64_t, int64_t>
  getReductionDimOrder(ArrayRef<int64_t> reductionDims,
                       xegpu::DistributeLayoutAttr layout) const {
    assert(layout.isForSubgroup() && "Must know the lane layout");
    assert(reductionDims.size() == 2 && "Expected 2D reduction");
    int64_t intra = -1, cross = -1;
    xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
    if (auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
      layoutAttr =
          dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
    if (!layoutAttr)
      return {intra, cross};
    SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
    assert(laneLayout.size() && "Expected a non-empty layout");
    for (auto dim : reductionDims) {
      // A dim not distributed across lanes is reduced within a lane.
      if (laneLayout[dim] == 1)
        intra = dim;
      else
        cross = dim;
    }
    return {intra, cross};
  }
};
} // namespace

void xegpu::populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
               VectorExtractOpPattern, MultiRed2dOpPattern>(
      patterns.getContext());
}

namespace {
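/// Pass driver: checks that the module targets a supported chip (PVC or
/// BMG), marks optimizable ops as illegal, and applies the patterns above as
/// a partial conversion.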
struct XeGPUPeepHoleOptimizerPass final
    : public xegpu::impl::XeGPUPeepHoleOptimizerPassBase<
          XeGPUPeepHoleOptimizerPass> {
  void runOnOperation() override {
    // Run only on modules that target a supported chip (PVC or BMG).
    bool isTargetSupported = false;
    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
      auto chipStr = xegpu::getChipStr(funcOp);
      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
        isTargetSupported = true;
    });
    if (!isTargetSupported) {
      DBGS() << "XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
             << "\n";
      return;
    }
    ConversionTarget target(getContext());
    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
        [&](xegpu::CreateNdDescOp createNdOp) {
          return !canBeOptimizedForTranspose(createNdOp.getType());
        });
    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
        [&](xegpu::LoadNdOp loadNdOp) {
          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
        });
    // vector.extract stays legal unless its result carries a layout that
    // qualifies for the transpose optimization.
    target.addDynamicallyLegalOp<vector::ExtractOp>(
        [&](vector::ExtractOp extractOp) {
          auto layout = xegpu::getTemporaryLayout(
              dyn_cast<OpResult>(extractOp.getResult()));
          if (!layout)
            return true;
          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
          auto laneData = layout.getEffectiveLaneDataAsInt();
          return !canBeOptimizedForTranspose(laneLayout, laneData);
        });
    target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
        [=](Operation *op) -> bool {
          auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
          if (!layout || !layout.isForSubgroup())
            return true;
          if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
            return reductionOp.getReductionDims().size() != 2;
          return true;
        });
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           vector::VectorDialect>();
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUPeepHoleOptimizerPatterns(patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      DBGS() << "Optimize block loads pass failed.\n";
      return signalPassFailure();
    }
  }
};
} // namespace