#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-optimize-block-loads"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
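
// Returns the effective 2D lane data of the tensor descriptor's layout
// attribute, or std::nullopt if no 2D lane data is available.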
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneData = layout.getEffectiveLaneDataAsInt();
  if (laneData.size() != 2)
    return std::nullopt;
  return laneData;
}
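
// Returns the effective 2D lane layout of the tensor descriptor's layout
// attribute, or std::nullopt if no 2D lane layout is available.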
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
  auto layout = tdescType.getLayoutAttr();
  if (!layout)
    return std::nullopt;
  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
  if (laneLayout.size() != 2)
    return std::nullopt;
  return laneLayout;
}
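
// A lane layout/lane data pair qualifies for this optimization when lanes are
// distributed along dim 0 (laneLayout = [N, 1], N > 1) and each lane owns
// multiple contiguous elements along dim 1 (laneData = [1, M], M > 1), i.e.
// the descriptor describes a transposed (column-major) block access.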
static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
                                       ArrayRef<int64_t> laneData) {
  if (laneLayout.size() != 2 || laneData.size() != 2)
    return false;
  if (laneLayout[0] == 1 || laneLayout[1] != 1)
    return false;
  if (laneData[0] != 1 || laneData[1] == 1)
    return false;
  return true;
}
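
// Returns true if the tensor descriptor has a sub-32-bit element type and a
// lane layout/lane data combination that qualifies for the transpose
// optimization.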
static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  if (elementTyBitwidth >= 32)
    return false;
  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
  auto maybeLaneData = getMaybeLaneData(tdescType);
  if (!maybeLaneData || !maybeLaneLayout)
    return false;
  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
}
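
// Tries to convert a qualifying tensor descriptor into one with a wider
// integer element type (folding innerLaneData consecutive elements into one)
// and a block shape supported by the target's Subgroup2DBlockLoad
// instruction. Returns the original type if no supported configuration is
// found.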
static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
                                         const uArch *targetuArch) {
  if (!canBeOptimizedForTranspose(tdescType))
    return tdescType;
  auto laneData = getMaybeLaneData(tdescType).value();
  int64_t innerLaneData = laneData[1];
  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
  auto requiredShape = llvm::to_vector(tdescType.getShape());
  requiredShape.back() =
      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
  int newBitWidth = elementTyBitwidth * innerLaneData;
  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
  auto maybeHWParams =
      blockLoadTarget->getBlockWidthHeightCount(newElemTy, false, true);
  if (!maybeHWParams)
    return tdescType;
  auto [widths, heights, counts] = maybeHWParams.value();
  if (counts.size() != 1 || counts[0] != 1)
    return tdescType;
  int arrayLen = counts[0];
  int supportedHeight =
      getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
  int supportedWidth =
      getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
  if (supportedHeight == -1 || supportedWidth == -1)
    return tdescType;
  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
      tdescType.getContext(),
      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                    tdescType.getBoundaryCheck(),
                                    tdescType.getMemorySpace(), newLayout);
}
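
// Materializes an OpFoldResult as a Value, creating an index constant when it
// holds a static integer.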
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                            OpFoldResult ofr) {
  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
  if (mayBeInt)
    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
  return llvm::cast<Value>(ofr);
}
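
// Divides `val` by `constant`, emitting a right shift when the divisor is a
// power of 2 and an unsigned division otherwise.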
static Value divideByConstant(ConversionPatternRewriter &rewriter,
                              Location loc, Value val, int64_t constant) {
  if (llvm::isPowerOf2_64(constant)) {
    int64_t shiftAmount = llvm::Log2_64(constant);
    Value shiftCst =
        arith::ConstantIndexOp::create(rewriter, loc, shiftAmount);
    return arith::ShRUIOp::create(rewriter, loc, val, shiftCst).getResult();
  }
  Value constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
}
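
// Emits a grid of smaller LoadNdOps (with the block shape of `newTensorDesc`)
// that together cover the shape of `data`, inserting each partial result into
// `data` and returning the fully populated vector.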
static Value generateLoads(ConversionPatternRewriter &rewriter,
                           TypedValue<VectorType> data,
                           ArrayRef<OpFoldResult> offsets,
                           TypedValue<xegpu::TensorDescType> newTensorDesc,
                           xegpu::LoadNdOp origLoadOp) {
  Location loc = origLoadOp.getLoc();
  assert(offsets.size() >= 2 &&
         "Expecting at least 2 offsets for 2D LoadNdOp");
  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
  ArrayRef<int64_t> supportedShape = newTensorDesc.getType().getShape();
  SmallVector<int64_t> shapeRatio =
      computeShapeRatio(data.getType().getShape(), supportedShape).value();
  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
      int64_t localOffsetDim0 = h * supportedShape[0];
      int64_t localOffsetDim1 = w * supportedShape[1];
      Value loadOffsetX = arith::AddIOp::create(
          rewriter, loc, offsetDim0,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0));
      Value loadOffsetY = arith::AddIOp::create(
          rewriter, loc, offsetDim1,
          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1));
      auto loadOp = xegpu::LoadNdOp::create(
          rewriter, loc,
          VectorType::get(supportedShape, data.getType().getElementType()),
          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
          origLoadOp.getL3HintAttr());
      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
      xegpu::setDistributeLayoutAttr(loadOp->getResult(0), layoutAttr);
      auto insertOp = vector::InsertStridedSliceOp::create(
          rewriter, loc, loadOp.getResult(), data,
          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
          ArrayRef<int64_t>{1, 1});
      data = insertOp.getResult();
    }
  }
  return data;
}
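
// Rewrites a CreateNdDescOp whose descriptor qualifies for the transpose
// optimization: the descriptor is recreated with the converted type from
// tryOptimize, with the innermost size and the second-to-last stride divided
// by innerLaneData. Statically shaped memref sources are passed as raw i64
// pointers so that the adjusted shape and strides take effect.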
class XeGPUCreateNdDescOpPattern final
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tdescTy = createNdOp.getType();
    auto chipStr = xegpu::getChipStr(createNdOp);
    assert(chipStr &&
           (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
           "Expecting target chip to be pvc or bmg for transpose optimization.");
    const uArch *targetuArch = getUArch(chipStr.value());
    auto convertType = tryOptimize(tdescTy, targetuArch);
    if (convertType == tdescTy)
      return failure();
    auto strides = createNdOp.getMixedStrides();
    auto maybeConstInnerStride = getConstantIntValue(strides.back());
    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
      return rewriter.notifyMatchFailure(
          createNdOp, "Expecting row-major memref for transpose optimization.");
    Value source = createNdOp.getSource();
    auto optionalLaneData = getMaybeLaneData(tdescTy);
    assert(optionalLaneData && "Expected 2D lane data");
    auto laneData = optionalLaneData.value();
    int64_t innerLaneData = laneData[1];
    auto memrefType = dyn_cast<MemRefType>(source.getType());
    SmallVector<OpFoldResult> modifiedShape = createNdOp.getMixedSizes();
    modifiedShape.back() = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
        innerLaneData);
    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
        rewriter, createNdOp.getLoc(),
        convertToValue(rewriter, createNdOp.getLoc(),
                       modifiedStrides[modifiedStrides.size() - 2]),
        innerLaneData);
    if (memrefType && memrefType.hasStaticShape()) {
      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
          rewriter, createNdOp.getLoc(), source);
      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
                                          rewriter.getI64Type(),
                                          extractOp.getResult())
                   .getResult();
    }
    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
        modifiedStrides);
    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
    return success();
  }
};
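
// Rewrites a LoadNdOp whose tensor descriptor has been converted: the load
// offsets are rescaled, the data is fetched in hardware-supported blocks via
// generateLoads, and the result is bitcast back to the original element type.
// Descriptors with array_length > 1 are expanded into one slice per array
// element and replaced 1:N.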
class XeGPULoadNdDescOpPattern final
    : public OpConversionPattern<xegpu::LoadNdOp> {
public:
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origTensorDescType = loadNdOp.getTensorDescType();
    auto adaptorType =
        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
    if (adaptorType == origTensorDescType)
      return failure();
    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
    int64_t innerLaneData = laneData[1];
    auto offsets = loadNdOp.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(loadNdOp,
                                         "Expecting offsets in LoadNd");
    SmallVector<OpFoldResult> modifiedOffsets(offsets);
    modifiedOffsets.back() = divideByConstant(
        rewriter, loadNdOp.getLoc(),
        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
        innerLaneData);
    auto origDataShape = llvm::to_vector(origTensorDescType.getShape());
    origDataShape.back() /= innerLaneData;
    VectorType origVectorType =
        VectorType::get(origDataShape, adaptorType.getElementType());
    Value data;
    if (origTensorDescType.getArrayLength() > 1) {
      SmallVector<Value> arraySlices;
      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
        Value slice = arith::ConstantOp::create(
            rewriter, loadNdOp->getLoc(), origVectorType,
            rewriter.getZeroAttr(origVectorType));
        Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                       modifiedOffsets.back());
        modifiedOffsets.back() =
            arith::AddIOp::create(
                rewriter, loadNdOp->getLoc(), offsetY,
                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
                                               i * origDataShape[1]))
                .getResult();
        slice = generateLoads(
            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
            loadNdOp);
        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                           origTensorDescType.getElementType());
        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                                   bitcastType, slice);
        xegpu::setDistributeLayoutAttr(bitCastOp->getResult(0),
                                       origTensorDescType.getLayoutAttr());
        arraySlices.push_back(bitCastOp.getResult());
      }
      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
      return success();
    }
    data = arith::ConstantOp::create(
        rewriter, loadNdOp->getLoc(),
        VectorType::get(origDataShape, adaptorType.getElementType()),
        rewriter.getZeroAttr(origVectorType));
    data = generateLoads(
        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
        loadNdOp);
    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
                                               loadNdOp.getType(), data);
    xegpu::setDistributeLayoutAttr(bitCastOp->getResult(0),
                                   origTensorDescType.getLayoutAttr());
    rewriter.replaceOp(loadNdOp, bitCastOp);
    return success();
  }
};
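
// When a LoadNdOp with array_length > 1 has been replaced with multiple
// slices, a vector.extract at a constant position i is folded to the i-th
// replacement value.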
class VectorExtractOpPattern final
    : public OpConversionPattern<vector::ExtractOp> {
public:
  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getSource().size() == 1)
      return failure();
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() != 1)
      return failure();
    auto mayBeInt = getConstantIntValue(mixedPos[0]);
    if (!mayBeInt)
      return failure();
    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
    return success();
  }
};
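
// Appends the block-load optimization patterns defined above to `patterns`.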
void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
               VectorExtractOpPattern>(patterns.getContext());
}
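
// Pass that rewrites block loads carrying transposed layouts into forms
// supported by the target's 2D block-load hardware. Only PVC and BMG targets
// are supported.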
struct XeGPUOptimizeBlockLoadsPass final
    : public xegpu::impl::XeGPUOptimizeBlockLoadsBase<
          XeGPUOptimizeBlockLoadsPass> {
  void runOnOperation() override {
    bool isTargetSupported = false;
    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
      auto chipStr = xegpu::getChipStr(funcOp);
      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
        isTargetSupported = true;
    });
    if (!isTargetSupported) {
      DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets."
             << "\n";
      return;
    }
    ConversionTarget target(getContext());
    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
        [&](xegpu::CreateNdDescOp createNdOp) {
          return !canBeOptimizedForTranspose(createNdOp.getType());
        });
    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
        [&](xegpu::LoadNdOp loadNdOp) {
          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
        });
    target.addDynamicallyLegalOp<vector::ExtractOp>(
        [&](vector::ExtractOp extractOp) {
          auto layout = xegpu::getDistributeLayoutAttr(extractOp.getSource());
          if (!layout)
            return true;
          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
          auto laneData = layout.getEffectiveLaneDataAsInt();
          return !canBeOptimizedForTranspose(laneLayout, laneData);
        });
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           vector::VectorDialect>();
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      DBGS() << "Optimize block loads pass failed.\n";
      return signalPassFailure();
    }
  }
};