16#include "llvm/ADT/SmallVector.h"
18#define DEBUG_TYPE "xegpu-array-length-optimization"
// Fallback subgroup size used when the target's subgroup size cannot be
// determined (see the getSubgroupSize fallbacks below).
26constexpr int64_t DEFAULT_SUBGROUP_SIZE = 16;
// NOTE(review): only two return statements of getSubgroupSize survived this
// extraction. Both fall back to DEFAULT_SUBGROUP_SIZE, presumably when the
// chip string / uArch lookup (getChipStr / getUArch) fails — confirm against
// the original file.
34 return DEFAULT_SUBGROUP_SIZE;
38 return DEFAULT_SUBGROUP_SIZE;
// Fragment of computeArrayLength: number of subgroup-wide pieces the
// fastest-changing dimension splits into.
47 if (fcdSize <= subgroupSize)
// NOTE(review): the body of the <= branch was lost in extraction —
// presumably `return 1;`; confirm against the original file.
49 return fcdSize / subgroupSize;
56static bool needsOptimization(xegpu::TensorDescType tdescType,
58 auto shape = tdescType.getShape();
59 if (
shape.size() != 2)
63 if (fcd % subgroupSize != 0)
66 return fcd > subgroupSize && tdescType.getArrayLength() == 1;
71static bool hasNonIdentityTranspose(xegpu::LoadNdOp loadOp) {
72 auto transpose = loadOp.getTranspose();
76 return !(perm.size() == 2 && perm[0] == 0 && perm[1] == 1);
// Returns true when the descriptor's 2D lane layout distributes lanes along
// the outer (non-contiguous) dimension, i.e. laneLayout is [N, 1] with
// N != 1; such descriptors are skipped by the patterns below.
84static bool hasTransposeLaneLayout(xegpu::TensorDescType tdescType) {
85 auto layout = tdescType.getLayoutAttr();
// NOTE(review): extraction lost the lines that null-check `layout` and pull
// `laneLayout` out of it — confirm against the original file.
89 if (laneLayout.size() != 2)
91 return laneLayout[0] != 1 && laneLayout[1] == 1;
// Pattern: rewrite xegpu.create_nd_tdesc so that a wide 2D block descriptor
// is expressed as `arrayLength` narrower blocks via the descriptor's
// array-length feature.
// NOTE(review): this is a lossy extraction of the original file; pattern
// boilerplate, the early `return failure()` bodies, the loop header over the
// op's users, and the computation of `newShape` are missing — confirm
// against the original source.
101class OptimizeCreateNdDescOp :
    public OpRewritePattern<xegpu::CreateNdDescOp> {
105 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
// Subgroup size comes from the target (DEFAULT_SUBGROUP_SIZE fallback).
107 int64_t subgroupSize = getSubgroupSize(op);
108 auto tdescType = op.getType();
// Bail out unless the descriptor is a 2D, array-length-1 descriptor whose
// FCD divides by and exceeds the subgroup size.
109 if (!needsOptimization(tdescType, subgroupSize))
// Transposed lane layouts keep their original descriptor.
114 if (hasTransposeLaneLayout(tdescType))
// Only memref or raw-pointer (integer) sources are handled.
117 Value source = op.getSource();
118 if (!isa<MemRefType, IntegerType>(source.
getType()))
// Skip the rewrite if any consuming load applies a non-identity transpose.
123 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(user))
124 if (hasNonIdentityTranspose(loadOp))
// Split the FCD (shape[1]) into `arrayLength` subgroup-wide pieces.
128 auto shape = tdescType.getShape();
129 int64_t arrayLength = computeArrayLength(
shape[1], subgroupSize);
// Rebuild the descriptor type with the narrowed shape and the new array
// length; other descriptor properties are carried over unchanged.
132 auto newTdescType = xegpu::TensorDescType::get(
133 newShape, tdescType.getElementType(), arrayLength,
134 tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
135 tdescType.getLayout());
// Recreate the op with the same source/sizes/strides but the new type, and
// replace all uses of the original descriptor.
140 auto newOp = xegpu::CreateNdDescOp::create(
141 rewriter, op.getLoc(), newTdescType, source, op.getMixedSizes(),
142 op.getMixedStrides());
143 rewriter.
replaceOp(op, newOp.getResult());
// Method of the OptimizeLoadNdOp pattern (the class header lies outside this
// extraction): fixes up the result vector type of xegpu.load_nd after its
// descriptor was given an array length > 1.
153 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
155 auto tdescType = op.getTensorDescType();
156 int64_t arrayLength = tdescType.getArrayLength();
// Nothing to do for descriptors that were not widened.
158 if (arrayLength <= 1)
// Transposing loads / transposed lane layouts are excluded, mirroring the
// create_nd_tdesc pattern.
163 if (hasNonIdentityTranspose(op) || hasTransposeLaneLayout(tdescType))
// Only 2D result vectors are handled.
166 auto origVectorType = op.getType();
167 auto origShape = origVectorType.getShape();
168 if (origShape.size() != 2)
// Expected result: blockHeight * arrayLength rows by the narrowed width.
// Skip if the load already has that shape.
172 int64_t expectedNonFCD = tdescType.getShape()[0] * arrayLength;
173 int64_t expectedFCD = tdescType.getShape()[1];
176 if (origShape[0] == expectedNonFCD && origShape[1] == expectedFCD)
// NOTE(review): the line computing `newShape` was lost in extraction —
// presumably {expectedNonFCD, expectedFCD}; confirm against the original.
182 VectorType::get(newShape, origVectorType.getElementType());
// Recreate the load with the corrected result type, carrying over offsets,
// packed/transpose attributes and all three cache hints.
185 auto newLoadOp = xegpu::LoadNdOp::create(
186 rewriter, op.getLoc(), newVectorType, op.getTensorDesc(),
187 op.getMixedOffsets(), op.getPackedAttr(), op.getTransposeAttr(),
188 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
191 rewriter.
replaceOp(op, newLoadOp.getResult());
// Pattern: update vector.extract_strided_slice ops that slice the result of
// an array-length load, so their offsets match the load's new result layout.
// NOTE(review): lossy extraction — pattern boilerplate, the early
// `return failure()` bodies, the null check of `loadOp`, and the computation
// of `newOffsets`/`sliceSizes` are missing; confirm against the original.
223class UpdateExtractStridedSliceOp
228 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
// Only 2D vector sources are handled.
230 auto sourceType = dyn_cast<VectorType>(op.getSource().getType());
231 if (!sourceType || sourceType.getRank() != 2)
// The slice must come directly from an xegpu.load_nd whose descriptor uses
// array length > 1.
234 auto loadOp = op.getSource().getDefiningOp<xegpu::LoadNdOp>();
238 auto tdescType = loadOp.getTensorDescType();
239 int64_t arrayLength = tdescType.getArrayLength();
240 if (arrayLength <= 1)
// A full 2D slice spec (offset/size/stride per dimension) is required.
243 auto offsets = op.getOffsets().getValue();
244 auto sizes = op.getSizes().getValue();
245 auto strides = op.getStrides().getValue();
247 if (offsets.size() != 2 || sizes.size() != 2 || strides.size() != 2)
250 int64_t origOffset0 = cast<IntegerAttr>(offsets[0]).getInt();
251 int64_t origOffset1 = cast<IntegerAttr>(offsets[1]).getInt();
// Descriptor geometry: rows per block and the (narrowed) block width.
253 int64_t blockHeight = tdescType.getShape()[0];
254 int64_t arrayWidth = tdescType.getShape()[1];
// Offsets inside the first block need no remapping.
259 if (origOffset1 < arrayWidth)
264 assert(origOffset1 % arrayWidth == 0 &&
265 "extract offset along FCD must be a multiple of the array width");
// Index of the block (out of arrayLength) in which the slice starts.
267 int64_t arrayIndex = origOffset1 / arrayWidth;
// Local helper: convert an array of IntegerAttrs to int64_t values.
272 return llvm::to_vector(llvm::map_range(
273 arr, [](
Attribute a) {
return cast<IntegerAttr>(a).getInt(); }));
// Re-emit the slice with the remapped offsets and replace the original.
278 auto newOp = vector::ExtractStridedSliceOp::create(
279 rewriter, op.getLoc(), op.getSource(), newOffsets, sliceSizes,
282 rewriter.
replaceOp(op, newOp.getResult());
// Body of populateXeGPUArrayLengthOptimizationPatterns (the signature lies
// outside this extraction): registers the three rewrite patterns above into
// the caller-provided pattern set.
291 patterns.
add<OptimizeCreateNdDescOp, OptimizeLoadNdOp,
292 UpdateExtractStridedSliceOp>(patterns.
getContext());
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
virtual int getSubgroupSize() const =0