#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  // ... (constructor elided in this listing)

  // Return the unroll (tile) shape for `op`, or std::nullopt if the op should
  // not be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--no filter constraint -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects the native shape for native shape call back function.");
    auto nativeShape = options.nativeShape(op);
    return nativeShape;
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }
  // Merge the per-tile values in `srcs` into a single value of type `destTy`:
  // via vector.insert_strided_slice for vectors, or via a tagged
  // unrealized_conversion_cast for tensor descriptors.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());

      // Start from an all-zero vector and insert each tile at its offsets.
      Value result = rewriter.create<arith::ConstantOp>(
          loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
      for (auto [src, offsets] :
           llvm::zip(srcs, StaticTileOffsetRange(shape, blockSize))) {
        SmallVector<int64_t> staticStrides(blockSize.size(), 1);
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, src, result, offsets, staticStrides);
      }
      return result;
    }

    if (isa<xegpu::TensorDescType>(destTy)) {
      // (construction of the pack/unpack marker attributes elided)
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs /* ..., marker attributes elided */);
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }
  // Split `src` into per-tile values with the given `destTypes`: via
  // vector.extract_strided_slice for vectors, or via a tagged
  // unrealized_conversion_cast for tensor descriptors.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      auto shape = vecTy.getShape();

      SmallVector<Value> results;
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(shape, blockSize)) {
        SmallVector<int64_t> staticStrides(blockSize.size(), 1);
        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, src, offsets, blockSize, staticStrides);
        results.push_back(slice);
      }
      return results;
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      // (construction of the pack/unpack marker attributes elided)
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src /* ..., marker attributes elided */);
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return {};
  }
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

    // Add a constant to an OpFoldResult, folding when both are constants.
    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt)
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      auto aV = llvm::cast<Value>(a);
      auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
    };

    // Only the trailing `rank` offsets participate in the unrolling; leading
    // offsets of a higher-rank source are kept as-is.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {
      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
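
// Illustrative only (shapes and offsets invented for the example): for a
// 32x32 descriptor with a target shape of 8x16, UnrollCreateNdOp emits one
// create_nd_tdesc per tile, shifting only the trailing `rank` offsets by the
// tile offsets from StaticTileOffsetRange, and then groups the results with
// the tagged unrealized_conversion_cast built by `unpack`:
//
//   %td = xegpu.create_nd_tdesc %src[%x, %y] ... -> !xegpu.tensor_desc<32x32xf16>
//     ==>
//   %td00 = xegpu.create_nd_tdesc %src[%x, %y]       -> !xegpu.tensor_desc<8x16xf16>
//   %td01 = xegpu.create_nd_tdesc %src[%x, %y + 16]  -> !xegpu.tensor_desc<8x16xf16>
//   %td10 = xegpu.create_nd_tdesc %src[%x + 8, %y]   -> !xegpu.tensor_desc<8x16xf16>
//   ...   (8 descriptors in total for the 4x2 tile grid)
//   %td   = builtin.unrealized_conversion_cast %td00, %td01, ...
//           (tagged with the __xegpu_blocking_* attributes)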
struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
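
// Illustrative only (shapes invented for the example): with the descriptors
// already split as above, a 32x32 load with target shape 8x16 becomes one
// 8x16 load per tile, and the tiles are recombined into the original vector
// type by `unpack`:
//
//   %v = xegpu.load_nd %td : !xegpu.tensor_desc<32x32xf16> -> vector<32x32xf16>
//     ==>
//   %v00 = xegpu.load_nd %td00 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
//   ...   (8 loads, one per tile)
//   %v    = chain of vector.insert_strided_slice ops assembling the tiles
//           back into vector<32x32xf16>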
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // The target shape is expected to hold three values: M, K, N.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};

    // Split a 2D vector operand into tiles of `blockSize`.
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);
    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Bail out if any operand could not be blocked, or if nothing needs
    // blocking at all (every operand already is a single tile).
    SmallVector<ArrayRef<Value>> ranges({aVals, bVals});
    if (c)
      ranges.push_back(cVals);
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;

    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with the accumulator tile

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
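
// Worked example for the loop nest above (operand shapes invented): with
// A : vector<32x32>, B : vector<32x32> and {M, K, N} = {8, 16, 16}, we get
// mIters = 4, kIters = 2, nIters = 2, i.e. 4 * 2 = 8 result tiles. Each tile
// is an accumulation chain of kIters dpas ops:
//
//   C[i][j] = dpas(A[i][1], B[1][j], dpas(A[i][0], B[0][j], C[i][j]))
//
// with A[i][k] = aVals[i * kIters + k], B[k][j] = bVals[k * nIters + j], and
// C[i][j] = cVals[i * nIters + j], matching the index math in the loops.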
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}
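
// A minimal usage sketch (assumptions: the UnrollOptions members are set
// directly, mirroring how they are read above via options.nativeShape /
// options.getUnrolledTypes / options.filterConstraint; the exact
// option-setting API and the greedy driver call may differ):
//
//   xegpu::UnrollOptions options;
//   options.nativeShape = [](Operation *op)
//       -> std::optional<SmallVector<int64_t>> {
//     // e.g. {8, 16} for 2D nd ops, {8, 16, 16} (M, K, N) for xegpu.dpas.
//     ...
//   };
//   options.getUnrolledTypes = [](ShapedType type, ArrayRef<int64_t> tile) {
//     ...
//   };
//
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));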