VectorToXeGPU.cpp
//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//
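
// As an illustration only (simplified, hand-written syntax; attributes and
// in-bounds details elided), a 2D read such as
//
//   %v = vector.transfer_read %src[%i, %j], %pad
//        : memref<8x16xf32>, vector<8x16xf32>
//
// is rewritten into a tensor descriptor creation followed by a block load,
// roughly:
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//        : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc
//        : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>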

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only vector as the basic vector store and load ops guarantee
  // XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
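
  // A projected permutation with no zero results means every map result is a
  // plain AffineDimExpr, so the cast below cannot fail. The transfer may only
  // address the innermost dimensions of the source: the trailing vecRank
  // results must refer to the last vecRank input dimensions.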
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

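// Builds an xegpu::CreateNdDescOp describing the accessed memref region.
// For statically shaped sources only the offsets are needed; for sources with
// any dynamic dimension, the shape and strides are also materialized as SSA
// values and passed to the descriptor explicitly.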
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
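    // Walking from the second-to-last dimension outward, the running product
    // of dimension sizes reproduces row-major strides
    // (stride[i] = dim[i+1] * ... * dim[rank-1]); only the strides that are
    // dynamic in the memref type are emitted as SSA values.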
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

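// Lowers vector.transfer_read on a memref source to xegpu.create_nd_tdesc
// followed by xegpu.load_nd. Non-identity (transposed) reads are handled by
// building the descriptor on the pre-transpose shape and attaching a {1, 0}
// transpose to the load.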
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

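// Lowers vector.transfer_write with an identity permutation map to
// xegpu.create_nd_tdesc followed by xegpu.store_nd.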
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType,
        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
        writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

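// Lowers vector.load to xegpu.create_nd_tdesc plus xegpu.load_nd. Boundary
// checking is enabled only for 2D (block) loads.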
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

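// Lowers vector.store to xegpu.create_nd_tdesc plus xegpu.store_nd. Boundary
// checking is enabled only for 2D (block) stores.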
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

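// Lowers vector.contract in plain row-major matmul form with an additive
// accumulator to xegpu.dpas. VNNI packing of the operands is left to a
// separate DPAS lowering step.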
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // TODO: Update shape validation to be target aware.
    auto accShape = accType.getShape();
    int64_t dimN = accShape[1];
    if (dimN != 8 && dimN != 16)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid operand dimensions");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

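// Pass driver that greedily applies the conversion patterns above. The pass is
// exposed through the generated ConvertVectorToXeGPU pass declarations
// (typically registered as convert-vector-to-xegpu).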
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}