MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
22 
23 namespace mlir {
24 namespace linalg {
25 namespace {
26 /// Verify that the runtime sizes of the operands to linalg structured ops are
27 /// compatible with the runtime sizes inferred by composing the loop ranges with
28 /// the linalg op's indexing maps. This is similar to the verifier except that
29 /// here we insert IR to perform the verification at runtime.
30 template <typename T>
31 struct StructuredOpInterface
32  : public RuntimeVerifiableOpInterface::ExternalModel<
33  StructuredOpInterface<T>, T> {
34  void
35  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
36  function_ref<std::string(Operation *, StringRef)>
37  generateErrorMessage) const {
38  auto linalgOp = llvm::cast<LinalgOp>(op);
39 
40  SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
41  auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
42 
43  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
44  auto one = arith::ConstantIndexOp::create(builder, loc, 1);
45 
46  // Subtract one from the loop ends before composing with the indexing map
47  transform(ends, ends.begin(), [&](OpFoldResult end) {
48  auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
49  return builder.createOrFold<index::SubOp>(loc, endValue, one);
50  });
51 
52  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
53  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
55  builder, loc, indexingMap, starts);
57  builder, loc, indexingMap, ends);
58 
59  for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
60  auto startIndex =
61  getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
62  auto endIndex =
63  getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
64 
65  // Generate:
66  // minIndex = min(startIndex, endIndex)
67  // assert(minIndex >= 0)
68  // To ensure we do not generate a negative index. We take the minimum of
69  // the start and end indices in order to handle reverse loops such as
70  // `affine_map<(i) -> (3 - i)>`
71  auto min =
72  builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
73  auto cmpOp = builder.createOrFold<index::CmpOp>(
74  loc, index::IndexCmpPredicate::SGE, min, zero);
75  auto msg = generateErrorMessage(
76  linalgOp, "unexpected negative result on dimension #" +
77  std::to_string(dim) + " of input/output operand #" +
78  std::to_string(opOperand.getOperandNumber()));
79  builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
80 
81  // Generate:
82  // inferredDimSize = max(startIndex, endIndex) + 1
83  // actualDimSize = dim(operand)
84  // assert(inferredDimSize <= actualDimSize)
85  // To ensure that we do not index past the bounds of the operands.
86  auto max =
87  builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
88 
89  auto inferredDimSize =
90  builder.createOrFold<index::AddOp>(loc, max, one);
91 
92  auto actualDimSize =
93  createOrFoldDimOp(builder, loc, opOperand.get(), dim);
94 
95  // Similar to the verifier, when the affine expression in the indexing
96  // map is complicated, we just check that the inferred dimension sizes
97  // are in the boundary of the operands' size. Being more precise than
98  // that is difficult.
99  auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
100  ? index::IndexCmpPredicate::EQ
101  : index::IndexCmpPredicate::SLE;
102 
103  cmpOp = builder.createOrFold<index::CmpOp>(
104  loc, predicate, inferredDimSize, actualDimSize);
105  msg = generateErrorMessage(
106  linalgOp, "dimension #" + std::to_string(dim) +
107  " of input/output operand #" +
108  std::to_string(opOperand.getOperandNumber()) +
109  " is incompatible with inferred dimension size");
110  builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
111  }
112  }
113  }
114 };
115 
116 template <typename... OpTs>
117 void attachInterface(MLIRContext *ctx) {
118  (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
119 }
120 } // namespace
121 } // namespace linalg
122 } // namespace mlir
123 
125  DialectRegistry &registry) {
126  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
127  attachInterface<
128 #define GET_OP_LIST
129 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
130  >(ctx);
131 
132  // Load additional dialects of which ops may get created.
133  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
134  cf::ControlFlowDialect, index::IndexDialect,
135  tensor::TensorDialect>();
136  });
137 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:110
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1374
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:95
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) formed by separating out the individual elements of each range.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111