//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
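  // Illustrative note on the check below: for a 3-D source and a 2-D vector,
  // maps such as (d0, d1, d2) -> (d1, d2) or the transposed
  // (d0, d1, d2) -> (d2, d1) are accepted, while (d0, d1, d2) -> (d0, d1) is
  // rejected because it accesses a non-innermost dimension.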
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
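// Builds an XeGPU tensor descriptor for the given memref source. A rough,
// illustrative sketch of the produced IR (the exact XeGPU assembly syntax may
// differ):
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//
// For dynamically shaped sources, the offsets, shape, and strides are passed
// as explicit operands instead of being taken from the memref type.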
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = getStridesAndOffset(srcTy);

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to be
    // provided explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // The last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
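// Lowers vector.transfer_read to an XeGPU block load. A rough, illustrative
// sketch of the rewrite (the exact XeGPU assembly syntax may differ):
//
//   %v = vector.transfer_read %src[%i, %j], %pad
//       : memref<8x16xf32>, vector<8x16xf32>
//
// becomes roughly
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//       : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>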
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape{vecTy.getShape()};
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
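// Lowers vector.transfer_write to an XeGPU block store. A rough, illustrative
// sketch (the exact XeGPU assembly syntax may differ):
//
//   vector.transfer_write %v, %dst[%i, %j]
//       : vector<8x16xf32>, memref<8x16xf32>
//
// becomes roughly
//
//   %desc = xegpu.create_nd_tdesc %dst[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc
//       : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>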
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
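// Lowers vector.load to xegpu.create_nd_tdesc followed by xegpu.load_nd,
// analogous to the transfer_read lowering above but without permutation or
// padding handling (illustrative description of the pattern below).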
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};
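// Lowers vector.store to xegpu.create_nd_tdesc followed by xegpu.store_nd,
// mirroring the vector.load lowering above (illustrative description of the
// pattern below).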
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering>(patterns.getContext());
}
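// Usage sketch (assuming the standard MLIR PassManager API and a PassManager
// instance named pm): the pass can be added to a pipeline with
//   pm.addPass(createConvertVectorToXeGPUPass());
// or invoked from mlir-opt via the flag generated from the pass definition
// (presumably --convert-vector-to-xegpu).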
std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
  return std::make_unique<ConvertVectorToXeGPUPass>();
}