//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
// NOTE: The include paths below follow the MLIR 14-era source tree layout.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  return true;
}
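
// Illustrative example (shapes and element type chosen arbitrarily, not taken
// from a test): the accepted form is the canonical row-major matmul with no
// masks attached, roughly
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                           affine_map<(m, n, k) -> (k, n)>,
//                           affine_map<(m, n, k) -> (m, n)>],
//          iterator_types = ["parallel", "parallel", "reduction"]}
//        %a, %b, %c
//        : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>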

// Return the stride of dimension 0 of |type| if it is a memref with a
// constant stride, llvm::None otherwise.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.permutation_map();
  OpBuilder b(readOp.getContext());
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());
  // TODO: Support transpose once it is added to GPU dialect ops.
  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
  return map.isMinorIdentity() || map == broadcastInnerDim;
}
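
// For illustration (values and shapes are made up), both of the following
// reads satisfy the check above: the first uses the implicit minor identity
// map, the second broadcasts a row across the outer vector dimension.
//
//   %0 = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<128x128xf16>, vector<16x16xf16>
//   %1 = vector.transfer_read %mem[%i, %j], %pad
//       {in_bounds = [true, true],
//        permutation_map = affine_map<(d0, d1) -> (0, d1)>}
//       : memref<128x128xf16>, vector<16x16xf16>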

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.source().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}
144 
146  if (isa<scf::ForOp, scf::YieldOp>(op))
147  return true;
148  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
149  return transferReadSupportsMMAMatrixType(transferRead);
150  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
151  return transferWriteSupportsMMAMatrixType(transferWrite);
152  if (auto contract = dyn_cast<vector::ContractionOp>(op))
154  if (auto constant = dyn_cast<arith::ConstantOp>(op))
155  return constantSupportsMMAMatrixType(constant);
156  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
159 }

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice
/// elements that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};
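
// Illustrative effect of the pattern above (shapes are arbitrary): a contract
// whose RHS map is (m, n, k) -> (n, k) is rewritten by transposing the RHS
// and emitting the canonical (m, k) x (k, n) -> (m, n) contract, roughly:
//
//   %rhs_t = vector.transpose %rhs, [1, 0]
//       : vector<8x16xf16> to vector<16x8xf16>
//   %d = vector.contract { ...canonical maps and iterators... }
//        %lhs, %rhs_t, %acc ...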

// Merge transpose op into the transfer read op. Transposes are not supported
// on MMA types but the MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 4> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        AffineMapAttr::get(newMap), transferReadOp.padding(),
        transferReadOp.mask(), transferReadOp.in_boundsAttr());
    return success();
  }
};
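
// Roughly, this pattern folds (illustrative IR, not taken from a test)
//
//   %0 = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<32x32xf16>, vector<4x8xf16>
//   %1 = vector.transpose %0, [1, 0] : vector<4x8xf16> to vector<8x4xf16>
//
// into a single vector.transfer_read producing vector<8x4xf16> whose
// permutation_map composes the transpose, here
// affine_map<(d0, d1) -> (d1, d0)>.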

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}
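
// For example, a value consumed as the LHS of a vector.contract is tagged
// "AOp", the RHS is tagged "BOp", and anything else (accumulator, results,
// constants) defaults to "COp", producing fragment types such as
// !gpu.mma_matrix<16x16xf16, "AOp"> (shape here is only illustrative).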

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.permutation_map();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}
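
// Illustrative lowering (indices, shapes and padding value are made up);
// leadDimension is the constant stride of the source memref and the fragment
// kind is chosen by inferFragType from the use:
//
//   %v = vector.transfer_read %mem[%i, %j], %cst {in_bounds = [true, true]}
//       : memref<32x32xf16>, vector<16x16xf16>
//
// becomes, roughly:
//
//   %m = gpu.subgroup_mma_load_matrix %mem[%i, %j] {leadDimension = 32 : index}
//       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">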

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}
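
// Roughly (illustrative shapes), a 2-D vector.transfer_write of the mapped
// value becomes:
//
//   gpu.subgroup_mma_store_matrix %m, %mem[%i, %j] {leadDimension = 32 : index}
//       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>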

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}
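
// Roughly (illustrative types), the contraction becomes a single warp-level
// multiply-accumulate on the mapped fragments:
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//       : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//         -> !gpu.mma_matrix<16x16xf16, "COp">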

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}
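
// Roughly (illustrative splat value and shape):
//
//   %cst = arith.constant dense<0.000000e+00> : vector<16x16xf16>
//
// becomes:
//
//   %c0 = arith.constant 0.000000e+00 : f16
//   %m = gpu.subgroup_mma_constant_matrix %c0
//       : !gpu.mma_matrix<16x16xf16, "COp">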

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.source());
  valueMapping[op.getResult()] = matrix;
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}
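
// For instance, an scf.for carrying a vector accumulator that has an MMA
// counterpart gains an extra iter_arg (sketch, names made up):
//
//   scf.for %i = %lb to %ub step %s
//       iter_args(%acc = %init, %accMat = %initMat)
//       -> (vector<16x16xf16>, !gpu.mma_matrix<16x16xf16, "COp">) { ... }
//
// The original vector result becomes dead once its users are converted and is
// expected to be removed by later dead-code cleanup.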

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}
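
// Roughly (illustrative types; the exact assembly form may differ), an
// elementwise arith.addf on two mapped 2-D vectors becomes something like:
//
//   %r = gpu.subgroup_mma_elementwise addf %a, %b
//       : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
//         -> !gpu.mma_matrix<16x16xf16, "COp">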

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

} // namespace mlir
namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(getOperation().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

    convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}
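
// The conversion is typically driven from mlir-opt; assuming the pass keeps
// its usual registration name from the conversion Passes.td, e.g.:
//
//   mlir-opt input.mlir -convert-vector-to-gpu
//
// This first canonicalizes contractions via the patterns above and then maps
// the rewritten vector ops onto gpu.subgroup_mma_* operations.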