MLIR  20.0.0git
VectorToXeGPU.cpp
Go to the documentation of this file.
1 //===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to XeGPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 // Return true if value represents a zero constant.
37 static bool isZeroConstant(Value val) {
38  auto constant = val.getDefiningOp<arith::ConstantOp>();
39  if (!constant)
40  return false;
41 
42  return TypeSwitch<Attribute, bool>(constant.getValue())
43  .Case<FloatAttr>(
44  [](auto floatAttr) { return floatAttr.getValue().isZero(); })
45  .Case<IntegerAttr>(
46  [](auto intAttr) { return intAttr.getValue().isZero(); })
47  .Default([](auto) { return false; });
48 }
49 
50 static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
51  Operation *op, VectorType vecTy) {
52  // Validate only vector as the basic vector store and load ops guarantee
53  // XeGPU-compatible memref source.
54  unsigned vecRank = vecTy.getRank();
55  if (!(vecRank == 1 || vecRank == 2))
56  return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
57 
58  return success();
59 }
60 
61 static LogicalResult transferPreconditions(PatternRewriter &rewriter,
62  VectorTransferOpInterface xferOp) {
63  if (xferOp.getMask())
64  return rewriter.notifyMatchFailure(xferOp,
65  "Masked transfer is not supported");
66 
67  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
68  if (!srcTy)
69  return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
70 
71  // Perform common data transfer checks.
72  VectorType vecTy = xferOp.getVectorType();
73  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
74  return failure();
75 
76  // Validate further transfer op semantics.
77  SmallVector<int64_t> strides;
78  int64_t offset;
79  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
80  strides.back() != 1)
81  return rewriter.notifyMatchFailure(
82  xferOp, "Buffer must be contiguous in the innermost dimension");
83 
84  unsigned vecRank = vecTy.getRank();
85  AffineMap map = xferOp.getPermutationMap();
86  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
87  return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
88  unsigned numInputDims = map.getNumInputs();
89  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
90  auto dim = dyn_cast<AffineDimExpr>(expr);
91  if (dim.getPosition() < (numInputDims - vecRank))
92  return rewriter.notifyMatchFailure(
93  xferOp, "Only the innermost dimensions can be accessed");
94  }
95 
96  return success();
97 }
98 
99 static xegpu::CreateNdDescOp
100 createNdDescriptor(PatternRewriter &rewriter, Location loc,
101  xegpu::TensorDescType descType, TypedValue<MemRefType> src,
102  Operation::operand_range offsets) {
103  MemRefType srcTy = src.getType();
104  auto [strides, offset] = getStridesAndOffset(srcTy);
105 
106  xegpu::CreateNdDescOp ndDesc;
107  if (srcTy.hasStaticShape()) {
108  ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
109  getAsOpFoldResult(offsets));
110  } else {
111  // In case of any dynamic shapes, source's shape and strides have to be
112  // explicitly provided.
113  SmallVector<Value> sourceDims;
114  unsigned srcRank = srcTy.getRank();
115  for (unsigned i = 0; i < srcRank; ++i)
116  sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));
117 
118  SmallVector<int64_t> constOffsets;
119  SmallVector<Value> dynOffsets;
120  for (Value offset : offsets) {
121  std::optional<int64_t> staticVal = getConstantIntValue(offset);
122  if (!staticVal)
123  dynOffsets.push_back(offset);
124  constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
125  }
126 
127  SmallVector<Value> dynShapes;
128  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
129  if (shape == ShapedType::kDynamic)
130  dynShapes.push_back(sourceDims[idx]);
131  }
132 
133  // Compute strides in reverse order.
134  SmallVector<Value> dynStrides;
135  Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
136  // Last stride is guaranteed to be static and unit.
137  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
138  accStride =
139  rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
140  if (strides[i] == ShapedType::kDynamic)
141  dynStrides.push_back(accStride);
142  }
143  std::reverse(dynStrides.begin(), dynStrides.end());
144 
145  ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
146  loc, descType, src, dynOffsets, dynShapes, dynStrides,
147  DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
148  DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
149  DenseI64ArrayAttr::get(rewriter.getContext(), strides));
150  }
151 
152  return ndDesc;
153 }
154 
155 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
157 
158  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
159  PatternRewriter &rewriter) const override {
160  Location loc = readOp.getLoc();
161 
162  if (failed(transferPreconditions(rewriter, readOp)))
163  return failure();
164 
165  bool isOutOfBounds = readOp.hasOutOfBoundsDim();
166  if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
167  return rewriter.notifyMatchFailure(
168  readOp, "Unsupported non-zero padded out-of-bounds read");
169 
170  AffineMap readMap = readOp.getPermutationMap();
171  bool isTransposeLoad = !readMap.isMinorIdentity();
172 
173  VectorType vecTy = readOp.getVectorType();
174  Type elementType = vecTy.getElementType();
175  unsigned minTransposeBitWidth = 32;
176  if (isTransposeLoad &&
177  elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
178  return rewriter.notifyMatchFailure(
179  readOp, "Unsupported data type for tranposition");
180 
181  // If load is transposed, get the base shape for the tensor descriptor.
182  SmallVector<int64_t> descShape{vecTy.getShape()};
183  if (isTransposeLoad)
184  std::reverse(descShape.begin(), descShape.end());
185  auto descType = xegpu::TensorDescType::get(
186  descShape, elementType, /*array_length=*/1,
187  /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
188 
189  xegpu::CreateNdDescOp ndDesc =
190  createNdDescriptor(rewriter, loc, descType,
191  dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
192  readOp.getIndices());
193 
194  DenseI64ArrayAttr transposeAttr =
195  !isTransposeLoad ? nullptr
196  : DenseI64ArrayAttr::get(rewriter.getContext(),
197  ArrayRef<int64_t>{1, 0});
198  // By default, no specific caching policy is assigned.
199  xegpu::CachePolicyAttr hint = nullptr;
200  auto loadOp = rewriter.create<xegpu::LoadNdOp>(
201  loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
202  /*l1_hint=*/hint,
203  /*l2_hint=*/hint, /*l3_hint=*/hint);
204  rewriter.replaceOp(readOp, loadOp);
205 
206  return success();
207  }
208 };
209 
210 struct TransferWriteLowering
211  : public OpRewritePattern<vector::TransferWriteOp> {
213 
214  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
215  PatternRewriter &rewriter) const override {
216  Location loc = writeOp.getLoc();
217 
218  if (failed(transferPreconditions(rewriter, writeOp)))
219  return failure();
220 
221  AffineMap map = writeOp.getPermutationMap();
222  if (!map.isMinorIdentity())
223  return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
224 
225  VectorType vecTy = writeOp.getVectorType();
226  auto descType = xegpu::TensorDescType::get(
227  vecTy.getShape(), vecTy.getElementType(),
228  /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
229  xegpu::MemorySpace::Global);
230  xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
231  rewriter, loc, descType,
232  dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
233  writeOp.getIndices());
234 
235  // By default, no specific caching policy is assigned.
236  xegpu::CachePolicyAttr hint = nullptr;
237  auto storeOp =
238  rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
239  /*l1_hint=*/hint,
240  /*l2_hint=*/hint, /*l3_hint=*/hint);
241  rewriter.replaceOp(writeOp, storeOp);
242 
243  return success();
244  }
245 };
246 
247 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
249 
250  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
251  PatternRewriter &rewriter) const override {
252  Location loc = loadOp.getLoc();
253 
254  VectorType vecTy = loadOp.getResult().getType();
255  if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
256  return failure();
257 
258  auto descType = xegpu::TensorDescType::get(
259  vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
260  /*boundary_check=*/true, xegpu::MemorySpace::Global);
261  xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
262  rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
263 
264  // By default, no specific caching policy is assigned.
265  xegpu::CachePolicyAttr hint = nullptr;
266  auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
267  loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
268  /*l1_hint=*/hint,
269  /*l2_hint=*/hint, /*l3_hint=*/hint);
270  rewriter.replaceOp(loadOp, loadNdOp);
271 
272  return success();
273  }
274 };
275 
276 struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
278 
279  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
280  PatternRewriter &rewriter) const override {
281  Location loc = storeOp.getLoc();
282 
283  TypedValue<VectorType> vector = storeOp.getValueToStore();
284  VectorType vecTy = vector.getType();
285  if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
286  return failure();
287 
288  auto descType =
289  xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
290  /*array_length=*/1, /*boundary_check=*/true,
291  xegpu::MemorySpace::Global);
292  xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
293  rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
294 
295  // By default, no specific caching policy is assigned.
296  xegpu::CachePolicyAttr hint = nullptr;
297  auto storeNdOp =
298  rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
299  /*l1_hint=*/hint,
300  /*l2_hint=*/hint, /*l3_hint=*/hint);
301  rewriter.replaceOp(storeOp, storeNdOp);
302 
303  return success();
304  }
305 };
306 
307 struct ConvertVectorToXeGPUPass
308  : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
309  void runOnOperation() override {
310  RewritePatternSet patterns(&getContext());
312  if (failed(
313  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
314  return signalPassFailure();
315  }
316 };
317 
318 } // namespace
319 
321  RewritePatternSet &patterns) {
322  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
323  StoreLowering>(patterns.getContext());
324 }
325 
326 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
327  return std::make_unique<ConvertVectorToXeGPUPass>();
328 }
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:155
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
std::unique_ptr< Pass > createConvertVectorToXeGPUPass()
Create a pass to convert ops from vector to XeGPU.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358