22 #include "llvm/ADT/TypeSwitch.h"
28 #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
29 #include "mlir/Conversion/Passes.h.inc"
37 static bool isZeroConstant(
Value val) {
44 [](
auto floatAttr) {
return floatAttr.getValue().isZero(); })
46 [](
auto intAttr) {
return intAttr.getValue().isZero(); })
47 .Default([](
auto) {
return false; });
54 unsigned vecRank = vecTy.getRank();
55 if (!(vecRank == 1 || vecRank == 2))
62 VectorTransferOpInterface xferOp) {
65 "Masked transfer is not supported");
67 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
72 VectorType vecTy = xferOp.getVectorType();
73 if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
82 xferOp,
"Buffer must be contiguous in the innermost dimension");
84 unsigned vecRank = vecTy.getRank();
90 auto dim = dyn_cast<AffineDimExpr>(expr);
91 if (dim.getPosition() < (numInputDims - vecRank))
93 xferOp,
"Only the innermost dimensions can be accessed");
99 static xegpu::CreateNdDescOp
103 MemRefType srcTy = src.getType();
106 xegpu::CreateNdDescOp ndDesc;
107 if (srcTy.hasStaticShape()) {
108 ndDesc = rewriter.
create<xegpu::CreateNdDescOp>(loc, descType, src,
114 unsigned srcRank = srcTy.getRank();
115 for (
unsigned i = 0; i < srcRank; ++i)
116 sourceDims.push_back(rewriter.
create<memref::DimOp>(loc, src, i));
120 for (
Value offset : offsets) {
123 dynOffsets.push_back(offset);
124 constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
129 if (shape == ShapedType::kDynamic)
130 dynShapes.push_back(sourceDims[idx]);
135 Value accStride = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
137 for (
int i =
static_cast<int>(strides.size()) - 2; i >= 0; --i) {
139 rewriter.
create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
140 if (strides[i] == ShapedType::kDynamic)
141 dynStrides.push_back(accStride);
143 std::reverse(dynStrides.begin(), dynStrides.end());
145 ndDesc = rewriter.
create<xegpu::CreateNdDescOp>(
146 loc, descType, src, dynOffsets, dynShapes, dynStrides,
155 struct TransferReadLowering :
public OpRewritePattern<vector::TransferReadOp> {
158 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
162 if (failed(transferPreconditions(rewriter, readOp)))
165 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
166 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
168 readOp,
"Unsupported non-zero padded out-of-bounds read");
173 VectorType vecTy = readOp.getVectorType();
174 Type elementType = vecTy.getElementType();
175 unsigned minTransposeBitWidth = 32;
176 if (isTransposeLoad &&
179 readOp,
"Unsupported data type for tranposition");
184 std::reverse(descShape.begin(), descShape.end());
186 descShape, elementType, 1,
187 isOutOfBounds, xegpu::MemorySpace::Global);
189 xegpu::CreateNdDescOp ndDesc =
190 createNdDescriptor(rewriter, loc, descType,
192 readOp.getIndices());
195 !isTransposeLoad ? nullptr
199 xegpu::CachePolicyAttr hint =
nullptr;
200 auto loadOp = rewriter.
create<xegpu::LoadNdOp>(
201 loc, vecTy, ndDesc,
nullptr, transposeAttr,
210 struct TransferWriteLowering
214 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
218 if (failed(transferPreconditions(rewriter, writeOp)))
225 VectorType vecTy = writeOp.getVectorType();
227 vecTy.getShape(), vecTy.getElementType(),
228 1, writeOp.hasOutOfBoundsDim(),
229 xegpu::MemorySpace::Global);
230 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
231 rewriter, loc, descType,
233 writeOp.getIndices());
236 xegpu::CachePolicyAttr hint =
nullptr;
238 rewriter.
create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
250 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
254 VectorType vecTy = loadOp.getResult().getType();
255 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
259 vecTy.getShape(), vecTy.getElementType(), 1,
260 true, xegpu::MemorySpace::Global);
261 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
262 rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
265 xegpu::CachePolicyAttr hint =
nullptr;
266 auto loadNdOp = rewriter.
create<xegpu::LoadNdOp>(
267 loc, vecTy, ndDesc,
nullptr,
nullptr,
279 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
284 VectorType vecTy = vector.getType();
285 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
291 xegpu::MemorySpace::Global);
292 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
293 rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
296 xegpu::CachePolicyAttr hint =
nullptr;
298 rewriter.
create<xegpu::StoreNdOp>(loc, vector, ndDesc,
307 struct ConvertVectorToXeGPUPass
308 :
public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
309 void runOnOperation()
override {
314 return signalPassFailure();
322 patterns.
add<TransferReadLowering, TransferWriteLowering, LoadLowering,
327 return std::make_unique<ConvertVectorToXeGPUPass>();
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::unique_ptr< Pass > createConvertVectorToXeGPUPass()
Create a pass to convert ops from vector to XeGPU.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...