29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
35#define GEN_PASS_DEF_XEGPUPEEPHOLEOPTIMIZER
36#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
40#define DEBUG_TYPE "xegpu-optimize-peephole"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
48static std::optional<SmallVector<int64_t>>
49getMaybeLaneData(xegpu::TensorDescType tdescType) {
50 auto layout = tdescType.getLayoutAttr();
53 auto laneData = layout.getEffectiveLaneDataAsInt();
54 if (laneData.size() != 2)
60static std::optional<SmallVector<int64_t>>
61getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
62 auto layout = tdescType.getLayoutAttr();
65 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
66 if (laneLayout.size() != 2)
84 if (laneLayout.size() != 2 || laneData.size() != 2)
86 if (laneLayout[0] == 1 || laneLayout[1] != 1)
88 if (laneData[0] != 1 || laneData[1] == 1)
95static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
97 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
98 if (elementTyBitwidth >= 32)
100 auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
101 auto maybeLaneData = getMaybeLaneData(tdescType);
102 if (!maybeLaneData || !maybeLaneLayout)
104 return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
109static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
110 const uArch *targetuArch) {
111 if (!canBeOptimizedForTranspose(tdescType))
113 auto laneData = getMaybeLaneData(tdescType)
116 int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
121 requiredShape.back() =
122 requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
123 int newBitWidth = elementTyBitwidth * innerLaneData;
124 Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
127 auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
129 auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
130 newElemTy,
false,
true);
134 auto [widths, heights, counts] = maybeHWParams.value();
136 if (counts.size() != 1 || counts[0] != 1)
138 int arrayLen = counts[0];
139 int supportedHeight =
144 if (supportedHeight == -1 || supportedWidth == -1)
148 xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
149 tdescType.getContext(), tdescType.getLayoutAttr().getLaneLayout(),
151 tdescType.getLayoutAttr().getOrder());
153 return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
154 tdescType.getBoundaryCheck(),
155 tdescType.getMemorySpace(), newLayout);
159static Value convertToValue(ConversionPatternRewriter &rewriter,
Location loc,
164 return llvm::cast<Value>(ofr);
168static Value divideByConstant(ConversionPatternRewriter &rewriter,
Location loc,
171 if (llvm::isPowerOf2_64(constant)) {
172 int64_t shiftAmount = llvm::Log2_64(constant);
173 return arith::ShRUIOp::create(
181 return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
187static Value generateLoads(ConversionPatternRewriter &rewriter,
191 xegpu::LoadNdOp origLoadOp) {
193 assert(offsets.size() >= 2 &&
"Expecting at least 2 offsets for 2D LoadNdOp");
194 Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
195 Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
202 for (
int64_t h = 0; h < shapeRatio[0]; ++h) {
203 for (
int64_t w = 0; w < shapeRatio[1]; ++w) {
204 int64_t localOffsetDim0 = h * supportedShape[0];
205 int64_t localOffsetDim1 = w * supportedShape[1];
206 Value loadOffsetX = arith::AddIOp::create(
207 rewriter, loc, offsetDim0,
210 Value loadOffsetY = arith::AddIOp::create(
211 rewriter, loc, offsetDim1,
214 auto loadOp = xegpu::LoadNdOp::create(
216 VectorType::get(supportedShape, data.getType().getElementType()),
218 origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
219 origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
220 origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
222 auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
223 loadOp.setAnchorLayout(layoutAttr);
225 auto insertOp = vector::InsertStridedSliceOp::create(
226 rewriter, loc, loadOp.getResult(), data,
231 data = insertOp.getResult();
241class XeGPUCreateNdDescOpPattern final
242 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
244 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
246 matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
247 ConversionPatternRewriter &rewriter)
const override {
248 auto tdescTy = createNdOp.getType();
253 chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg") &&
254 "Expecting target chip to be pvc or bmg for transpose optimization.");
257 auto convertType = tryOptimize(tdescTy, targetuArch);
258 if (convertType == tdescTy)
260 auto strides = createNdOp.getMixedStrides();
263 if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
264 return rewriter.notifyMatchFailure(
265 createNdOp,
"Expecting row-major memref for transpose optimization.");
266 Value source = createNdOp.getSource();
267 auto optionalLaneData = getMaybeLaneData(tdescTy);
268 assert(optionalLaneData &&
"Expected 2D lane data");
269 auto laneData = optionalLaneData.value();
270 int64_t innerLaneData = laneData[1];
271 auto memrefType = dyn_cast<MemRefType>(source.
getType());
273 SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
274 modifiedShape.back() = divideByConstant(
275 rewriter, createNdOp.getLoc(),
276 convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
279 assert(strides.size() >= 2 &&
280 "Expected at least 2 strides for CreateNdDescOp");
281 SmallVector<OpFoldResult> modifiedStrides(strides);
282 modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
283 rewriter, createNdOp.getLoc(),
284 convertToValue(rewriter, createNdOp.getLoc(),
285 modifiedStrides[modifiedStrides.size() - 2]),
290 if (memrefType && memrefType.hasStaticShape()) {
291 auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
292 rewriter, createNdOp.getLoc(), source);
293 source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
294 rewriter.getI64Type(),
295 extractOp.getResult())
299 auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
300 rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
302 rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
311class XeGPULoadNdDescOpPattern final
312 :
public OpConversionPattern<xegpu::LoadNdOp> {
314 using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
316 matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
317 ConversionPatternRewriter &rewriter)
const override {
318 auto origTensorDescType = loadNdOp.getTensorDescType();
320 cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
321 if (adaptorType == origTensorDescType)
324 auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
325 int64_t innerLaneData = laneData[1];
326 auto offsets = loadNdOp.getMixedOffsets();
328 return rewriter.notifyMatchFailure(loadNdOp,
329 "Expecting offsets in LoadNd");
330 SmallVector<OpFoldResult> modifiedOffsets(offsets);
331 modifiedOffsets.back() = divideByConstant(
332 rewriter, loadNdOp.getLoc(),
333 convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
337 SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
339 origDataShape.back() /= innerLaneData;
341 SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
342 VectorType origVectorType =
343 VectorType::get(origDataShape, adaptorType.getElementType());
346 if (origTensorDescType.getArrayLength() > 1) {
347 SmallVector<Value> arraySlices;
348 for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
349 Value slice = arith::ConstantOp::create(
350 rewriter, loadNdOp->getLoc(), origVectorType,
351 rewriter.getZeroAttr(origVectorType));
353 Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
354 modifiedOffsets.back());
355 modifiedOffsets.back() =
356 arith::AddIOp::create(
357 rewriter, loadNdOp->getLoc(), offsetY,
359 i * origDataShape[1])
362 slice = generateLoads(
367 auto bitcastType = VectorType::get(origTensorDescType.getShape(),
368 origTensorDescType.getElementType());
369 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
373 origTensorDescType.getLayoutAttr());
374 arraySlices.push_back(bitCastOp.getResult());
376 rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
379 data = arith::ConstantOp::create(
380 rewriter, loadNdOp->getLoc(),
381 VectorType::get(origDataShape, adaptorType.getElementType()),
382 rewriter.getZeroAttr(origVectorType));
383 data = generateLoads(
387 auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
388 loadNdOp.getType(), data);
391 origTensorDescType.getLayoutAttr());
392 rewriter.replaceOp(loadNdOp, bitCastOp);
401class VectorExtractOpPattern final
402 :
public OpConversionPattern<vector::ExtractOp> {
404 using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
406 matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter)
const override {
409 if (adaptor.getSource().size() == 1)
411 auto mixedPos = extractOp.getMixedPosition();
412 if (mixedPos.size() != 1)
417 rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
424class MultiRed2dOpPattern
425 :
public OpConversionPattern<vector::MultiDimReductionOp> {
426 using OpConversionPattern::OpConversionPattern;
428 matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter)
const override {
430 auto sourceVecType = reductionOp.getSourceVectorType();
431 if (reductionOp.getReductionDims().size() != 2)
432 return rewriter.notifyMatchFailure(reductionOp,
"Expected 2D reduction");
435 auto dims = llvm::to_vector(reductionOp.getReductionDims());
436 auto [intraLaneDim, crossLaneDim] = getReductionDimOrder(dims, resLayout);
438 if (intraLaneDim == -1 || crossLaneDim == -1) {
439 intraLaneDim = dims[0];
440 crossLaneDim = dims[1];
442 auto loc = reductionOp.getLoc();
443 auto acc = reductionOp.getAcc();
445 SmallVector<int64_t> accShape(sourceVecType.getShape());
446 accShape.erase(accShape.begin() + intraLaneDim);
447 Type eTy = sourceVecType.getElementType();
449 rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind());
451 Value intraLaneReduced = vector::MultiDimReductionOp::create(
452 rewriter, loc, reductionOp.getKind(), reductionOp.getSource(),
453 constNeutralVal, ArrayRef<int64_t>(intraLaneDim));
456 if (crossLaneDim > intraLaneDim)
458 Value crossLaneReduced = vector::MultiDimReductionOp::create(
459 rewriter, loc, reductionOp.getKind(), intraLaneReduced, acc,
460 ArrayRef<int64_t>(crossLaneDim));
461 assert(crossLaneReduced.
getType() == reductionOp.getResult().getType() &&
463 rewriter.replaceOp(reductionOp, crossLaneReduced);
468 std::pair<int64_t, int64_t>
469 getReductionDimOrder(ArrayRef<int64_t> reductionDims,
470 xegpu::DistributeLayoutAttr layout)
const {
471 assert(layout.isForSubgroup() &&
"Must know the lane layout");
472 assert(reductionDims.size() == 2 &&
"Expected 2D reduction");
473 int64_t intra, cross = -1;
474 xegpu::LayoutAttr layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
475 if (
auto layoutSliceAttr = dyn_cast<xegpu::SliceAttr>(layout))
477 dyn_cast<xegpu::LayoutAttr>(layoutSliceAttr.flatten().getParent());
479 SmallVector<int64_t> laneLayout = layoutAttr.getEffectiveLaneLayoutAsInt();
481 assert(laneLayout.size() &&
"Expected a non-empty layout");
483 for (
auto dim : reductionDims) {
484 if (laneLayout[dim] == 1)
489 return {intra, cross};
497 patterns.
add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
498 VectorExtractOpPattern, MultiRed2dOpPattern>(
504struct XeGPUPeepHoleOptimizerPass final
506 XeGPUPeepHoleOptimizerPass> {
507 void runOnOperation()
override {
515 bool isTargetSupported =
false;
516 getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
518 if (chipStr && (chipStr.value() ==
"pvc" || chipStr.value() ==
"bmg"))
519 isTargetSupported =
true;
522 if (!isTargetSupported) {
523 DBGS() <<
"XeGPUPeepHoleOptimizerPass only supports PVC and BMG targets."
530 target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
531 [&](xegpu::CreateNdDescOp createNdOp) {
532 return !canBeOptimizedForTranspose(createNdOp.getType());
534 target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
535 [&](xegpu::LoadNdOp loadNdOp) {
536 return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
541 target.addDynamicallyLegalOp<vector::ExtractOp>(
542 [&](vector::ExtractOp extractOp) {
544 dyn_cast<OpResult>(extractOp.getResult()));
547 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
548 auto laneData = layout.getEffectiveLaneDataAsInt();
549 return !canBeOptimizedForTranspose(laneLayout, laneData);
552 target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
553 [=](Operation *op) ->
bool {
555 if (!layout || !layout.isForSubgroup())
557 if (
auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op))
558 return reductionOp.getReductionDims().size() != 2;
562 converter.addConversion([](Type type) {
return type; });
564 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
565 vector::VectorDialect>();
569 if (
failed(applyPartialConversion(getOperation(),
target,
570 std::move(patterns)))) {
571 DBGS() <<
"Optimize block loads pass failed.\n";
572 return signalPassFailure();
577 RewritePatternSet emptyPatterns(ctx);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store ops).
void populateXeGPUPeepHoleOptimizerPatterns(RewritePatternSet &patterns)
Appends patterns for optimizing block load operations into patterns.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
const Instruction * getInstruction(InstructionKind instKind) const