#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
// Returns true if the given value is a compile-time constant equal to zero.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}
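
// Illustrative note (not from the original source): a value defined by
// `arith.constant 0.0 : f32` or `arith.constant 0 : i8` is classified as
// zero; any other producer, including non-constant values that happen to be
// zero at runtime, conservatively yields false.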

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Only 1D and 2D vectors map to XeGPU block load/store instructions.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order. The innermost stride is guaranteed
    // to be static and unit by the precondition checks.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
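
// Rough sketch (hand-written; the exact XeGPU assembly syntax is an
// assumption) of the descriptor built for a dynamically shaped source,
// where offsets, shape, and strides are passed explicitly:
//
//   %d0 = memref.dim %src, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %src, %c1 : memref<?x?xf32>
//   %desc = xegpu.create_nd_tdesc %src[%i, %j], [%d0, %d1], [%d1, %c1]
//       : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>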

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, the descriptor describes the in-memory
    // (reversed) shape; the transpose happens in-flight during the load.
    SmallVector<int64_t> descShape = llvm::to_vector(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
        readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
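
// Example lowering (hand-written sketch; the exact XeGPU assembly syntax is
// an assumption, not taken from the original source):
//
//   %v = vector.transfer_read %A[%i, %j], %c0_f32 {in_bounds = [true, true]}
//       : memref<8x16xf32>, vector<8x16xf32>
//
// becomes roughly:
//
//   %desc = xegpu.create_nd_tdesc %A[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>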

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint, /*l2_hint=*/hint,
                                          /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
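
// The symmetric write lowering, sketched under the same syntax assumptions:
//
//   vector.transfer_write %v, %A[%i, %j] {in_bounds = [true, true]}
//       : vector<8x16xf32>, memref<8x16xf32>
//
// becomes roughly:
//
//   %desc = xegpu.create_nd_tdesc %A[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc
//       : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>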

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
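
// vector.load follows the same path as the transfer_read lowering above
// (sketch): a load such as `%v = vector.load %A[%i, %j] : memref<8x16xf32>,
// vector<8x16xf32>` maps to a create_nd_tdesc + load_nd pair, with the
// boundary check enabled only for rank-2 (block) loads.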

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint, /*l2_hint=*/hint,
                                          /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};
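
// Likewise (sketch), `vector.store %v, %A[%i, %j] : vector<8x16xf32>,
// memref<8x16xf32>` maps to a create_nd_tdesc + store_nd pair.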

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace
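
// To exercise this conversion in isolation, one would typically run the pass
// through mlir-opt (assuming the flag keeps its tablegen-derived name):
//
//   mlir-opt --convert-vector-to-xegpu input.mlir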

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}