16#include "llvm/ADT/SmallVector.h"
18#define DEBUG_TYPE "xegpu-array-length-optimization"
26constexpr int64_t DEFAULT_SUBGROUP_SIZE = 16;
34 return DEFAULT_SUBGROUP_SIZE;
38 return DEFAULT_SUBGROUP_SIZE;
47 if (fcdSize <= subgroupSize)
49 return fcdSize / subgroupSize;
56static bool needsOptimization(xegpu::TensorDescType tdescType,
58 auto shape = tdescType.getShape();
59 if (
shape.size() != 2)
63 if (fcd % subgroupSize != 0)
66 return fcd > subgroupSize && tdescType.getArrayLength() == 1;
71static bool hasNonIdentityTranspose(xegpu::LoadNdOp loadOp) {
72 auto transpose = loadOp.getTranspose();
76 return !(perm.size() == 2 && perm[0] == 0 && perm[1] == 1);
84static bool hasTransposeLaneLayout(xegpu::TensorDescType tdescType) {
85 auto layout = tdescType.getLayoutAttr();
89 if (laneLayout.size() != 2)
91 return laneLayout[0] != 1 && laneLayout[1] == 1;
101class OptimizeCreateNdDescOp :
public OpRewritePattern<xegpu::CreateNdDescOp> {
105 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
108 if (op.getType().getElementTypeBitWidth() < 8)
110 int64_t subgroupSize = getSubgroupSize(op);
111 auto tdescType = op.getType();
112 if (!needsOptimization(tdescType, subgroupSize))
117 if (hasTransposeLaneLayout(tdescType))
120 Value source = op.getSource();
121 if (!isa<MemRefType, IntegerType>(source.
getType()))
126 if (
auto loadOp = dyn_cast<xegpu::LoadNdOp>(user))
127 if (hasNonIdentityTranspose(loadOp))
131 auto shape = tdescType.getShape();
132 int64_t arrayLength = computeArrayLength(
shape[1], subgroupSize);
135 auto newTdescType = xegpu::TensorDescType::get(
136 newShape, tdescType.getElementType(), arrayLength,
137 tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
138 tdescType.getLayout());
143 auto newOp = xegpu::CreateNdDescOp::create(
144 rewriter, op.getLoc(), newTdescType, source, op.getMixedSizes(),
145 op.getMixedStrides());
146 rewriter.
replaceOp(op, newOp.getResult());
156 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
158 auto tdescType = op.getTensorDescType();
159 int64_t arrayLength = tdescType.getArrayLength();
161 if (arrayLength <= 1)
166 if (hasNonIdentityTranspose(op) || hasTransposeLaneLayout(tdescType))
169 auto origVectorType = op.getType();
170 auto origShape = origVectorType.getShape();
171 if (origShape.size() != 2)
175 int64_t expectedNonFCD = tdescType.getShape()[0] * arrayLength;
176 int64_t expectedFCD = tdescType.getShape()[1];
179 if (origShape[0] == expectedNonFCD && origShape[1] == expectedFCD)
185 VectorType::get(newShape, origVectorType.getElementType());
188 auto newLoadOp = xegpu::LoadNdOp::create(
189 rewriter, op.getLoc(), newVectorType, op.getTensorDesc(),
190 op.getMixedOffsets(), op.getPackedAttr(), op.getTransposeAttr(),
191 op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
194 rewriter.
replaceOp(op, newLoadOp.getResult());
226class UpdateExtractStridedSliceOp
231 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
233 auto sourceType = dyn_cast<VectorType>(op.getSource().getType());
234 if (!sourceType || sourceType.getRank() != 2)
237 auto loadOp = op.getSource().getDefiningOp<xegpu::LoadNdOp>();
241 auto tdescType = loadOp.getTensorDescType();
242 int64_t arrayLength = tdescType.getArrayLength();
243 if (arrayLength <= 1)
246 auto offsets = op.getOffsets().getValue();
247 auto sizes = op.getSizes().getValue();
248 auto strides = op.getStrides().getValue();
250 if (offsets.size() != 2 || sizes.size() != 2 || strides.size() != 2)
253 int64_t origOffset0 = cast<IntegerAttr>(offsets[0]).getInt();
254 int64_t origOffset1 = cast<IntegerAttr>(offsets[1]).getInt();
256 int64_t blockHeight = tdescType.getShape()[0];
257 int64_t arrayWidth = tdescType.getShape()[1];
262 if (origOffset1 < arrayWidth)
267 assert(origOffset1 % arrayWidth == 0 &&
268 "extract offset along FCD must be a multiple of the array width");
270 int64_t arrayIndex = origOffset1 / arrayWidth;
275 return llvm::to_vector(llvm::map_range(
276 arr, [](
Attribute a) {
return cast<IntegerAttr>(a).getInt(); }));
281 auto newOp = vector::ExtractStridedSliceOp::create(
282 rewriter, op.getLoc(), op.getSource(), newOffsets, sliceSizes,
285 rewriter.
replaceOp(op, newOp.getResult());
294 patterns.
add<OptimizeCreateNdDescOp, OptimizeLoadNdOp,
295 UpdateExtractStridedSliceOp>(patterns.
getContext());
Attributes are known-constant values of operations.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
const uArch * getUArch(llvm::StringRef archName)
void populateXeGPUArrayLengthOptimizationPatterns(RewritePatternSet &patterns)
Appends patterns for array length optimization into patterns.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
virtual int getSubgroupSize() const =0