MLIR  22.0.0git
RuntimeOpVerification.cpp
Go to the documentation of this file.
1 //===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
23 
24 namespace mlir {
25 namespace linalg {
26 namespace {
27 /// Verify that the runtime sizes of the operands to linalg structured ops are
28 /// compatible with the runtime sizes inferred by composing the loop ranges with
29 /// the linalg op's indexing maps. This is similar to the verifier except that
30 /// here we insert IR to perform the verification at runtime.
31 template <typename T>
32 struct StructuredOpInterface
33  : public RuntimeVerifiableOpInterface::ExternalModel<
34  StructuredOpInterface<T>, T> {
35  void
36  generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
37  function_ref<std::string(Operation *, StringRef)>
38  generateErrorMessage) const {
39  auto linalgOp = llvm::cast<LinalgOp>(op);
40 
41  SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
42  auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
43 
44  auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
45  auto one = arith::ConstantIndexOp::create(builder, loc, 1);
46 
47  Value iterationDomainIsNonDegenerate;
48  for (auto [start, end] : llvm::zip(starts, ends)) {
49  auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
50  auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
51 
52  // Loop Trip count > 0 iff start < end
53  Value dimensionHasNonZeroTripCount = index::CmpOp::create(
54  builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
55 
56  if (!iterationDomainIsNonDegenerate) {
57  iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
58  } else {
59  // Iteration domain is non-degenerate iff all dimensions have loop trip
60  // count > 0
61  iterationDomainIsNonDegenerate =
62  arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
63  dimensionHasNonZeroTripCount);
64  }
65  }
66 
67  if (!iterationDomainIsNonDegenerate)
68  return;
69 
70  auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
71  /*withElseRegion=*/false);
72  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
73 
74  // Subtract one from the loop ends before composing with the indexing map
75  transform(ends, ends.begin(), [&](OpFoldResult end) {
76  auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
77  return builder.createOrFold<index::SubOp>(loc, endValue, one);
78  });
79 
80  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
81  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
83  builder, loc, indexingMap, starts);
85  builder, loc, indexingMap, ends);
86 
87  for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
88  auto startIndex =
89  getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
90  auto endIndex =
91  getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
92 
93  // Generate:
94  // minIndex = min(startIndex, endIndex)
95  // assert(minIndex >= 0)
96  // To ensure we do not generate a negative index. We take the minimum of
97  // the start and end indices in order to handle reverse loops such as
98  // `affine_map<(i) -> (3 - i)>`
99  auto min =
100  builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
101  auto cmpOp = builder.createOrFold<index::CmpOp>(
102  loc, index::IndexCmpPredicate::SGE, min, zero);
103  auto msg = generateErrorMessage(
104  linalgOp, "unexpected negative result on dimension #" +
105  std::to_string(dim) + " of input/output operand #" +
106  std::to_string(opOperand.getOperandNumber()));
107  builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
108 
109  // Generate:
110  // inferredDimSize = max(startIndex, endIndex) + 1
111  // actualDimSize = dim(operand)
112  // assert(inferredDimSize <= actualDimSize)
113  // To ensure that we do not index past the bounds of the operands.
114  auto max =
115  builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
116 
117  auto inferredDimSize =
118  builder.createOrFold<index::AddOp>(loc, max, one);
119 
120  auto actualDimSize =
121  createOrFoldDimOp(builder, loc, opOperand.get(), dim);
122 
123  // Similar to the verifier, when the affine expression in the indexing
124  // map is complicated, we just check that the inferred dimension sizes
125  // are in the boundary of the operands' size. Being more precise than
126  // that is difficult.
127  auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
128  ? index::IndexCmpPredicate::EQ
129  : index::IndexCmpPredicate::SLE;
130 
131  cmpOp = builder.createOrFold<index::CmpOp>(
132  loc, predicate, inferredDimSize, actualDimSize);
133  msg = generateErrorMessage(
134  linalgOp, "dimension #" + std::to_string(dim) +
135  " of input/output operand #" +
136  std::to_string(opOperand.getOperandNumber()) +
137  " is incompatible with inferred dimension size");
138  builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
139  }
140  }
141  builder.setInsertionPointAfter(ifOp);
142  }
143 };
144 
145 template <typename... OpTs>
146 void attachInterface(MLIRContext *ctx) {
147  (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
148 }
149 } // namespace
150 } // namespace linalg
151 } // namespace mlir
152 
154  DialectRegistry &registry) {
155  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
156  attachInterface<
157 #define GET_OP_LIST
158 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
159  >(ctx);
160 
161  // Load additional dialects of which ops may get created.
162  ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
163  cf::ControlFlowDialect, index::IndexDialect,
164  tensor::TensorDialect>();
165  });
166 }
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:110
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1514
void registerRuntimeVerifiableOpInterfaceExternalModels(DialectRegistry &registry)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:95
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) formed by separating out the individual elements of each range.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111