MLIR  14.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Parser.h"
25 
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/StringSet.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "llvm/Support/MathExtras.h"
33 #include "llvm/Support/raw_ostream.h"
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
39 
40 /// Forward declarations.
41 
42 /// Generic entry point to create the block for the region of a LinalgOp.
43 /// This is used by both named structured ops created by ods-gen and by manually
44 /// defined C++ ops.
45 /// This is used by both builders and parsers.
46 /// This function creates the block in the region with arguments corresponding
47 /// to the elemental types of `inputTypes` and `outputTypes`. The latter are
48 /// asserted to be of ShapedType.
49 template <typename NamedStructuredOpType>
50 static void fillStructuredOpRegion(
51  OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
52  TypeRange outputTypes,
53  llvm::function_ref<void(unsigned, unsigned)> errorHandler = nullptr);
54 
55 /// Generic entry point to create both the region and the block of a LinalgOp.
56 template <typename NamedStructuredOpType>
57 static void
59  TypeRange inputTypes, TypeRange outputTypes);
60 
61 /// Common parsing and printing used for both named structured ops created by
62 /// ods-gen and by manually defined C++ ops. Does not handle regions.
63 static ParseResult
65  SmallVectorImpl<Type> &inputTypes,
66  SmallVectorImpl<Type> &outputTypes);
67 template <typename NamedStructuredOpType>
69  NamedStructuredOpType op);
70 
71 /// Specific parsing and printing for named structured ops created by ods-gen.
72 template <typename NamedStructuredOpType>
73 static ParseResult
75  TypeRange inputTypes, TypeRange outputTypes);
76 
77 static ParseResult
79  SmallVectorImpl<Type> &resultTypes);
80 
81 template <typename NamedStructuredOpType>
83  OperationState &result);
84 
86  TypeRange resultTypes);
87 
88 template <typename NamedStructuredOpType>
89 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
90 
91 /// This is a common class used for patterns of the form
92 /// ```
93 /// someop(memrefcast(%src)) -> someop(%src)
94 /// ```
95 /// It folds the source of the memref.cast into the root operation directly.
97  bool folded = false;
98  for (OpOperand &operand : op->getOpOperands()) {
99  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
100  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
101  operand.set(castOp.getOperand());
102  folded = true;
103  }
104  }
105  return success(folded);
106 }
107 
108 /// This is a specialization of `foldMemRefCast` used for patterns of the form
109 /// ```
110 /// tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
111 /// ```
112 /// It folds the source of the memref.cast into the root operation directly.
114  bool folded = false;
115  Location loc = op->getLoc();
116 
117  Block *body = op.getBody();
119 
120  // Update `input` and `output` operands and block arguments if necessary.
121  // Operands list: [lbs, ubs, steps, inputs, outputs].
122  // Block args list: [ivs, inputs, outputs].
123  for (size_t operandIndex = op.getNumControlOperands(),
124  bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
125  operandIndex < e; ++operandIndex, ++bbArgIndex) {
126  OpOperand &operand = op->getOpOperand(operandIndex);
127 
128  auto castOp = operand.get().getDefiningOp<memref::CastOp>();
129  if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
130  operand.set(castOp.getOperand());
131  BlockArgument newBbArg = body->insertArgument(
132  bbArgIndex, castOp.getOperand().getType(), op.getLoc());
133  BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
134 
135  // Insert memref.cast back to the original type.
136  oldBbArg.replaceAllUsesWith(
137  b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
138  body->eraseArgument(oldBbArg.getArgNumber());
139 
140  folded = true;
141  }
142  }
143  return success(folded);
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Region builder helper.
148 // TODO: Move this to a utility library.
149 // The public methods on this class are referenced directly from generated code
150 // and bind by name to math and type conversion functions in the DSL as:
151 // `arithfn__{fnName}`
152 // `typefn__{fnName}`
153 // Examples:
154 // `arithfn__add`
155 // `arithfn__mul`
156 // `typefn__cast`
157 // The naming convention is intentional in order to match snake-cased DSL names.
158 // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
159 //
160 // Implementations of the math functions must be polymorphic over numeric types,
161 // internally performing necessary casts. If the function application makes no
162 // sense, then the only recourse is to assert and return nullptr. This can be
163 // extended later if it becomes possible to fail construction of the region. The
164 // invariant should be enforced at a higher level.
165 //
166 // TODO: These helpers are currently type polymorphic over the class of integer
167 // and floating point types, but they will not internally cast within bit
168 // widths of a class (mixed precision such as i8->i32) or across classes
169 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
170 // to be handled with care and work is being considered to extend the op
171 // language to make such cases explicit. In the mean-time, violating this will
172 // fail verification, which is deemed acceptable.
173 //===----------------------------------------------------------------------===//
174 
175 namespace {
176 
177 class RegionBuilderHelper {
178 public:
179  RegionBuilderHelper(MLIRContext *context, Block &block)
180  : context(context), block(block) {}
181 
182  // Generates operations to cast the given operand to a specified type.
183  // If the cast cannot be performed, a warning will be issued and the
184  // operand returned as-is (which will presumably yield a verification
185  // issue downstream).
186  Value cast(Type toType, Value operand, bool isUnsignedCast) {
187  OpBuilder builder = getBuilder();
188  auto loc = operand.getLoc();
189 
190  if (operand.getType() == toType)
191  return operand;
192  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
193  // If operand is floating point, cast directly to the int type.
194  if (operand.getType().isa<FloatType>()) {
195  if (isUnsignedCast)
196  return builder.create<arith::FPToUIOp>(loc, toType, operand);
197  return builder.create<arith::FPToSIOp>(loc, toType, operand);
198  }
199  // Cast index operands directly to the int type.
200  if (operand.getType().isIndex())
201  return builder.create<arith::IndexCastOp>(loc, toType, operand);
202  if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
203  // Either extend or truncate.
204  if (toIntType.getWidth() > fromIntType.getWidth()) {
205  if (isUnsignedCast)
206  return builder.create<arith::ExtUIOp>(loc, toType, operand);
207  return builder.create<arith::ExtSIOp>(loc, toType, operand);
208  }
209  if (toIntType.getWidth() < fromIntType.getWidth())
210  return builder.create<arith::TruncIOp>(loc, toType, operand);
211  }
212  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
213  // If operand is integer, cast directly to the float type.
214  // Note that it is unclear how to cast from BF16<->FP16.
215  if (operand.getType().isa<IntegerType>()) {
216  if (isUnsignedCast)
217  return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
218  return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
219  }
220  if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
221  if (toFloatType.getWidth() > fromFloatType.getWidth())
222  return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
223  if (toFloatType.getWidth() < fromFloatType.getWidth())
224  return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
225  }
226  }
227 
228  emitWarning(operand.getLoc()) << "could not cast operand of type "
229  << operand.getType() << " to " << toType;
230  return operand;
231  }
232 
233  // NOLINTNEXTLINE(*-identifier-naming): externally called.
234  Value typefn__cast(Type toType, Value operand) {
235  return cast(toType, operand, false);
236  }
237 
238  // NOLINTNEXTLINE(*-identifier-naming): externally called.
239  Value typefn__cast_unsigned(Type toType, Value operand) {
240  return cast(toType, operand, true);
241  }
242 
243  // NOLINTNEXTLINE(*-identifier-naming): externally called.
244  Value arithfn__add(Value lhs, Value rhs) {
245  OpBuilder builder = getBuilder();
246  if (isFloatingPoint(lhs))
247  return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
248  if (isInteger(lhs))
249  return builder.create<arith::AddIOp>(lhs.getLoc(), lhs, rhs);
250  llvm_unreachable("unsupported non numeric type");
251  }
252 
253  // NOLINTNEXTLINE(*-identifier-naming): externally called.
254  Value arithfn__exp(Value x) {
255  OpBuilder builder = getBuilder();
256  if (isFloatingPoint(x))
257  return builder.create<math::ExpOp>(x.getLoc(), x);
258  llvm_unreachable("unsupported non numeric type");
259  }
260 
261  // NOLINTNEXTLINE(*-identifier-naming): externally called.
262  Value arithfn__log(Value x) {
263  OpBuilder builder = getBuilder();
264  if (isFloatingPoint(x))
265  return builder.create<math::LogOp>(x.getLoc(), x);
266  llvm_unreachable("unsupported non numeric type");
267  }
268 
269  // NOLINTNEXTLINE(*-identifier-naming): externally called.
270  Value arithfn__sub(Value lhs, Value rhs) {
271  OpBuilder builder = getBuilder();
272  if (isFloatingPoint(lhs))
273  return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
274  if (isInteger(lhs))
275  return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
276  llvm_unreachable("unsupported non numeric type");
277  }
278 
279  // NOLINTNEXTLINE(*-identifier-naming): externally called.
280  Value arithfn__mul(Value lhs, Value rhs) {
281  OpBuilder builder = getBuilder();
282  if (isFloatingPoint(lhs))
283  return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
284  if (isInteger(lhs))
285  return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
286  llvm_unreachable("unsupported non numeric type");
287  }
288 
289  // NOLINTNEXTLINE(*-identifier-naming): externally called.
290  Value arithfn__max(Value lhs, Value rhs) {
291  OpBuilder builder = getBuilder();
292  if (isFloatingPoint(lhs))
293  return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
294  if (isInteger(lhs))
295  return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
296  llvm_unreachable("unsupported non numeric type");
297  }
298 
299  // NOLINTNEXTLINE(*-identifier-naming): externally called.
300  Value arithfn__max_unsigned(Value lhs, Value rhs) {
301  OpBuilder builder = getBuilder();
302  if (isFloatingPoint(lhs))
303  return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
304  if (isInteger(lhs))
305  return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
306  llvm_unreachable("unsupported non numeric type");
307  }
308 
309  // NOLINTNEXTLINE(*-identifier-naming): externally called.
310  Value arithfn__min(Value lhs, Value rhs) {
311  OpBuilder builder = getBuilder();
312  if (isFloatingPoint(lhs))
313  return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
314  if (isInteger(lhs))
315  return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
316  llvm_unreachable("unsupported non numeric type");
317  }
318 
319  // NOLINTNEXTLINE(*-identifier-naming): externally called.
320  Value arithfn__min_unsigned(Value lhs, Value rhs) {
321  OpBuilder builder = getBuilder();
322  if (isFloatingPoint(lhs))
323  return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
324  if (isInteger(lhs))
325  return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
326  llvm_unreachable("unsupported non numeric type");
327  }
328 
329  void yieldOutputs(ValueRange values) {
330  assert(!values.empty() && "linalg ops must yield outputs");
331  if (values.empty())
332  return;
333  Value first = values.front();
334  OpBuilder builder = getBuilder();
335  builder.create<YieldOp>(first.getLoc(), values);
336  }
337 
338  Value constant(const std::string &value) {
339  OpBuilder builder = getBuilder();
340  Location loc = builder.getUnknownLoc();
341  Attribute valueAttr = parseAttribute(value, builder.getContext());
342  return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
343  valueAttr);
344  }
345 
346  Value index(int64_t dim) {
347  OpBuilder builder = getBuilder();
348  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
349  }
350 
351  Type getIntegerType(unsigned width) {
352  return IntegerType::get(context, width);
353  }
354 
355  Type getFloat32Type() { return Float32Type::get(context); }
356 
357  Type getFloat64Type() { return Float64Type::get(context); }
358 
359 private:
360  MLIRContext *context;
361  Block &block;
362 
363  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
364  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
365 
366  OpBuilder getBuilder() {
367  OpBuilder builder(context);
368  builder.setInsertionPointToEnd(&block);
369  return builder;
370  }
371 };
372 
373 } // namespace
374 
375 //===----------------------------------------------------------------------===//
376 // CopyOp
377 //===----------------------------------------------------------------------===//
378 void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
379  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
380  b.create<linalg::YieldOp>(block.getArgument(0));
381 }
382 
/// Builds a CopyOp copying `input` into `output`. The optional permutation
/// maps are attached as attributes only when non-null, and the payload region
/// is created and populated via fillStructuredOpRegion.
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
                   Value output, AffineMap inputPermutation,
                   AffineMap outputPermutation,
                   ArrayRef<NamedAttribute> namedAttrs) {
  result.addOperands({input, output});
  result.addAttributes(namedAttrs);
  // Null maps are simply elided so the printed form stays minimal.
  if (inputPermutation)
    result.addAttribute("inputPermutation",
                        AffineMapAttr::get(inputPermutation));
  if (outputPermutation)
    result.addAttribute("outputPermutation",
                        AffineMapAttr::get(outputPermutation));
  result.addRegion();
  // Populate the region with the canonical copy payload.
  fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
                                 TypeRange{input.getType()},
                                 TypeRange{output.getType()});
}
400 
402  Type outputType) {
403  OpBuilder opBuilder(parser.getContext());
404  fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
405  TypeRange{outputType});
406  return success();
407 }
408 
409 /// CopyOp region is elided when printing.
411 
412 static LogicalResult verify(CopyOp op) {
413  OpOperand *output = op.getOutputOperand(0);
414  OpOperand *input = op.getInputOperand(0);
415  if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
416  return op.emitOpError("expects views of the same type");
417  if (op.getRank(input) != op.getRank(output))
418  return op.emitOpError("expects views of the same rank");
419  auto rank = op.getNumParallelLoops();
420  auto inputPermutationMap = op.inputPermutation();
421  if (inputPermutationMap) {
422  if (inputPermutationMap->getNumInputs() != rank)
423  return op.emitOpError("expects optional input_permutation map of rank ")
424  << rank;
425  if (!inputPermutationMap->isPermutation())
426  return op.emitOpError(
427  "expects optional input_permutation map to be a permutation");
428  }
429  auto outputPermutationMap = op.outputPermutation();
430  if (outputPermutationMap) {
431  if (outputPermutationMap->getNumInputs() != rank)
432  return op.emitOpError("expects optional output_permutation map of rank ")
433  << rank;
434  if (!outputPermutationMap->isPermutation())
435  return op.emitOpError(
436  "expects optional output_permutation map to be a permutation");
437  }
438  if (rank == 0 && inputPermutationMap)
439  return op.emitOpError("expected no input permutation when rank == 0");
440  if (rank == 0 && outputPermutationMap)
441  return op.emitOpError("expected no output permutation when rank == 0");
442  return success();
443 }
444 
445 void CopyOp::getEffects(
447  &effects) {
448  effects.emplace_back(MemoryEffects::Read::get(), input(),
450  effects.emplace_back(MemoryEffects::Write::get(), output(),
452 }
453 
454 namespace {
455 /// Remove copy operations that copy data inplace. Requirements are:
456 /// 1) The input and output values are identical.
457 /// 2) The input and output permutation maps are identical.
458 struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
460 
461  LogicalResult matchAndRewrite(CopyOp copyOp,
462  PatternRewriter &rewriter) const override {
463  assert(copyOp.hasBufferSemantics());
464  if (copyOp.input() == copyOp.output() &&
465  copyOp.inputPermutation() == copyOp.outputPermutation()) {
466  rewriter.eraseOp(copyOp);
467  return success();
468  }
469  return failure();
470  }
471 };
472 } // namespace
473 
/// Registers CopyOp canonicalization patterns: erasure of identity copies.
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseIdentityCopyOp>(context);
}
478 
479 //===----------------------------------------------------------------------===//
480 // FillOp
481 //===----------------------------------------------------------------------===//
482 void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
483  assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
484  b.create<linalg::YieldOp>(block.getArgument(0));
485 }
486 
/// Builds a FillOp writing `value` into `output`, then populates the payload
/// region.
void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
                   Value output) {
  // For a tensor output the dyn_cast yields the result type; for a buffer it
  // yields a null RankedTensorType — presumably the generated builder treats
  // that as "no result". TODO(review): confirm against the ODS overload.
  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
        output);
  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(),
                                 TypeRange{value.getType()},
                                 TypeRange{output.getType()}, {});
}
495 
497  Type outputType) {
498  OpBuilder opBuilder(parser.getContext());
499  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{valueType},
500  TypeRange{outputType});
501  return success();
502 }
503 
504 /// FillOp region is elided when printing.
506 
507 static LogicalResult verify(FillOp op) {
508  OpOperand *output = op.getOutputOperand(0);
509  Type fillType = op.value().getType();
510  if (getElementTypeOrSelf(output->get()) != fillType)
511  return op.emitOpError("expects fill type to match view elemental type");
512  return success();
513 }
514 
515 void FillOp::getEffects(
517  &effects) {
518  if (output().getType().isa<MemRefType>())
519  effects.emplace_back(MemoryEffects::Write::get(), output(),
521 }
522 
523 namespace {
524 
525 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
526 ///
527 /// For such op chains, we can create new linalg.fill ops with the result
528 /// type of the tensor.expand/collapse_shape op.
529 template <typename TensorReshapeOp>
530 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
532  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
533  PatternRewriter &rewriter) const override {
534  auto oldFill = reshapeOp.src().template getDefiningOp<FillOp>();
535  if (!oldFill)
536  return failure();
537 
538  Location loc = oldFill.getLoc();
539  auto newInit = rewriter.create<TensorReshapeOp>(
540  loc, reshapeOp.getResultType(), oldFill.output(),
541  reshapeOp.reassociation());
542  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, oldFill.value(), newInit);
543 
544  return success();
545  }
546 };
547 
548 } // namespace
549 
/// Registers FillOp canonicalization patterns: folding fills through
/// tensor.collapse_shape and tensor.expand_shape.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>>(context);
}
555 
556 //===----------------------------------------------------------------------===//
557 // GenericOps
558 //===----------------------------------------------------------------------===//
559 void GenericOp::build(
560  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
561  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
562  ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
563  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
564  ArrayRef<NamedAttribute> attributes) {
565  build(builder, result, resultTensorTypes, inputs, outputs,
566  builder.getAffineMapArrayAttr(indexingMaps),
567  builder.getStrArrayAttr(iteratorTypes),
568  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
569  libraryCall.empty() ? StringAttr()
570  : builder.getStringAttr(libraryCall));
571  result.addAttributes(attributes);
572  if (!bodyBuild)
573  return;
574 
575  SmallVector<Type, 4> blockArgTypes;
576  SmallVector<Location, 4> blockArgLocs;
577  for (ValueRange container : {inputs, outputs}) {
578  for (Value v : container) {
579  blockArgTypes.push_back(getElementTypeOrSelf(v));
580  blockArgLocs.push_back(v.getLoc());
581  }
582  }
583 
584  OpBuilder::InsertionGuard guard(builder);
585  auto &region = *result.regions.front();
586  Block *bodyBlock =
587  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
588  bodyBuild(builder, result.location, bodyBlock->getArguments());
589 }
590 
/// Overload without result tensor types: forwards to the main builder with an
/// empty result TypeRange.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
600 
/// Overload without doc/library_call (and without result tensor types):
/// forwards with empty strings for both.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
611 
/// Overload with result tensor types but without doc/library_call: forwards
/// with empty strings for both.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
623 
624 static void print(OpAsmPrinter &p, GenericOp op) {
625  p << " ";
626 
627  // Print extra attributes.
628  auto genericAttrNames = op.linalgTraitAttrNames();
629 
630  llvm::StringSet<> genericAttrNamesSet;
631  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
632  SmallVector<NamedAttribute, 8> genericAttrs;
633  for (auto attr : op->getAttrs())
634  if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
635  genericAttrs.push_back(attr);
636  if (!genericAttrs.empty()) {
637  auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
638  p << genericDictAttr;
639  }
640 
641  // Printing is shared with named ops, except for the region and attributes
643 
644  genericAttrNames.push_back("operand_segment_sizes");
645  genericAttrNamesSet.insert(genericAttrNames.back());
646 
647  bool hasExtraAttrs = false;
648  for (NamedAttribute n : op->getAttrs()) {
649  if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
650  break;
651  }
652  if (hasExtraAttrs) {
653  p << " attrs = ";
654  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/genericAttrNames);
655  }
656 
657  // Print region.
658  if (!op.region().empty()) {
659  p << ' ';
660  p.printRegion(op.region());
661  }
662 
663  // Print results.
664  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
665 }
666 
668  DictionaryAttr dictAttr;
669  // Parse the core linalg traits that must check into a dictAttr.
670  // The name is unimportant as we will overwrite result.attributes.
671  // The core linalg traits must contain the information necessary to pass the
672  // verifier.
673  if (parser.parseAttribute(dictAttr, "_", result.attributes))
674  return failure();
675  result.attributes.assign(dictAttr.getValue().begin(),
676  dictAttr.getValue().end());
677 
678  // Parsing is shared with named ops, except for the region.
679  SmallVector<Type, 1> inputTypes, outputTypes;
680  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
681  return failure();
682 
683  // Optional attributes may be added.
684  if (succeeded(parser.parseOptionalKeyword("attrs")))
685  if (failed(parser.parseEqual()) ||
686  failed(parser.parseOptionalAttrDict(result.attributes)))
687  return failure();
688 
690  std::unique_ptr<Region> region = std::make_unique<Region>();
691  SmallVector<Type, 8> operandTypes, regionTypes;
692  if (parser.parseRegion(*region, regionOperands, regionTypes))
693  return failure();
694  result.addRegion(std::move(region));
695 
696  // Generic ops may specify that a subset of its outputs are tensors. Such
697  // outputs are specified in the result type.
698  // TODO: may need to move output parsing before region parsing.
699  // Need to wait for declarative assembly resolution to decide.
700  SmallVector<Type, 1> outputTensorsTypes;
701  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
702  return failure();
703  result.addTypes(outputTensorsTypes);
704 
705  return success();
706 }
707 
710  &effects,
711  ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
712  for (Value value : results) {
713  effects.emplace_back(MemoryEffects::Allocate::get(), value,
715  }
716  for (Value value : inputBuffers) {
717  effects.emplace_back(MemoryEffects::Read::get(), value,
719  }
720  for (Value value : outputs) {
721  effects.emplace_back(MemoryEffects::Read::get(), value,
723  effects.emplace_back(MemoryEffects::Write::get(), value,
725  }
726 }
727 
728 void GenericOp::getEffects(
730  &effects) {
731  SmallVector<Value> inputBuffers = getInputBufferOperands();
732  SmallVector<Value> outputBuffers = getOutputBufferOperands();
733  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
734  outputBuffers);
735 }
736 
/// Op-specific verification hook for generic ops; currently performs no
/// additional checks beyond what the shared verifiers already enforce.
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
  return success();
}
741 
/// Verification entry point for linalg.generic; forwards to the templated
/// hook above.
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
743 
744 namespace {
745 // Deduplicate redundant args of a linalg generic op.
746 // An arg is redundant if it has the same Value and indexing map as another.
747 struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
749 
750  LogicalResult matchAndRewrite(GenericOp genericOp,
751  PatternRewriter &rewriter) const override {
752  // Associate each input to an equivalent "canonical" input that has the same
753  // Value and indexing map.
754  //
755  // In the non-duplicate case, input `i` will have canonical input `i`. But
756  // in the case of duplicated inputs, the canonical input could be some other
757  // input `< i`. That is, a later input will have some earlier input as its
758  // canonical input.
759  llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
760  // For later remapping tasks like deduplicating payload block arguments,
761  // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
762  // convenient.
763  SmallVector<unsigned> canonicalInputIndices;
764  for (OpOperand *opOperand : genericOp.getInputOperands()) {
765  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
766  // STL-like maps have a convenient behavior for our use case here. In the
767  // case of duplicate keys, the insertion is rejected, and the returned
768  // iterator gives access to the value already in the map.
769  auto pair = canonicalInput.insert(
770  {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
771  canonicalInputIndices.push_back(pair.first->second);
772  }
773 
774  // If there are no duplicate args, then bail out.
775  if (canonicalInput.size() == genericOp.getNumInputs())
776  return failure();
777 
778  // The operands for the newly canonicalized op.
779  SmallVector<Value> newInputOperands;
780  for (OpOperand *opOperand : genericOp.getInputOperands())
781  if (canonicalInputIndices[opOperand->getOperandNumber()] ==
782  opOperand->getOperandNumber())
783  newInputOperands.push_back(opOperand->get());
784 
785  // Repair the indexing maps by filtering out the ones that have been
786  // eliminated.
787  SmallVector<AffineMap> newIndexingMaps;
788  for (OpOperand *opOperand : genericOp.getInputOperands())
789  if (canonicalInputIndices[opOperand->getOperandNumber()] ==
790  opOperand->getOperandNumber())
791  newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
792  for (OpOperand *opOperand : genericOp.getOutputOperands())
793  newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
794 
795  // Clone the old op with new operands.
796  SmallVector<Value> outputOperands = genericOp.getOutputOperands();
797  auto newOp = rewriter.create<GenericOp>(
798  genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
799  outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
800  genericOp.iterator_types(), genericOp.docAttr(),
801  genericOp.library_callAttr());
802 
803  // Copy over unknown attributes. They might be load bearing for some flow.
804  ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
805  for (NamedAttribute kv : genericOp->getAttrs()) {
806  if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
807  newOp->setAttr(kv.getName(), kv.getValue());
808  }
809  }
810 
811  rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
812  newOp.region().begin());
813 
814  // Repair the payload entry block by RAUW'ing redundant arguments and
815  // erasing them.
816  Block &payload = newOp.region().front();
817  SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
818  for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
819  // Iterate in reverse, so that we erase later args first, preventing the
820  // argument list from shifting unexpectedly and invalidating all our
821  // indices.
822  unsigned operandNumber = opOperand->getOperandNumber();
823  if (canonicalInputIndices[operandNumber] == operandNumber)
824  continue;
825  payload.getArgument(operandNumber)
827  payload.getArgument(canonicalInputIndices[operandNumber]));
828  payload.eraseArgument(operandNumber);
829  }
830 
831  rewriter.replaceOp(genericOp, newOp->getResults());
832  return success();
833  }
834 };
835 
836 /// Remove generic operations (on tensors) that are just copying
837 /// the values from inputs to the results. Requirements are
838 /// 1) All iterator types are parallel
839 /// 2) The body contains just a yield operation with the yielded values being
840 /// the arguments corresponding to the operands.
841 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
843 
844  LogicalResult matchAndRewrite(GenericOp genericOp,
845  PatternRewriter &rewriter) const override {
846  // Check all indexing maps are identity.
847  if (llvm::any_of(genericOp.getIndexingMaps(),
848  [](AffineMap map) { return !map.isIdentity(); }))
849  return failure();
850 
851  // Check that the body of the linalg operation is just a linalg.yield
852  // operation.
853  Block &body = genericOp.region().front();
854  if (!llvm::hasSingleElement(body))
855  return failure();
856  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
857  if (!yieldOp)
858  return failure();
859 
860  // In the buffer case, we need to check exact buffer equality.
861  if (genericOp.hasBufferSemantics()) {
862  if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
863  genericOp.getInputOperand(0)->get() ==
864  genericOp.getOutputOperand(0)->get()) {
865  rewriter.eraseOp(genericOp);
866  return success();
867  }
868  return failure();
869  }
870 
871  // Get the argument number of the returned values. That is the operand
872  // number to use for replacing uses of this operation.
873  SmallVector<Value> returnedArgs;
874  for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
875  auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
876  if (!yieldArg || yieldArg.getOwner() != &body)
877  return failure();
878  unsigned argumentNumber = yieldArg.getArgNumber();
879  Value returnedArg = genericOp->getOperand(argumentNumber);
880  Type resultType = genericOp->getResult(yieldVal.index()).getType();
881  // The input can have a different type than the result, e.g. a dynamic
882  // input dimension can be turned into a static output dimension.
883  if (returnedArg.getType() != resultType)
884  returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
885  resultType, returnedArg);
886  returnedArgs.push_back(returnedArg);
887  }
888 
889  if (returnedArgs.size() != genericOp->getNumResults())
890  return failure();
891  rewriter.replaceOp(genericOp, returnedArgs);
892  return success();
893  }
894 };
895 } // namespace
896 
/// Registers the canonicalization patterns for linalg.generic: input
/// deduplication and erasure of identity (pure copy) generic ops.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
}
901 
902 //===----------------------------------------------------------------------===//
903 // InitTensorOp
904 //===----------------------------------------------------------------------===//
905 
906 void InitTensorOp::build(OpBuilder &b, OperationState &result,
907  ArrayRef<OpFoldResult> sizes, Type elementType,
908  ArrayRef<NamedAttribute> attrs) {
909  SmallVector<Value, 4> dynamicSizes;
910  SmallVector<int64_t, 4> staticSizes;
911  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
912  ShapedType::kDynamicSize);
913  auto resultType = RankedTensorType ::get(staticSizes, elementType);
914  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
915  result.addAttributes(attrs);
916 }
917 
/// Verifies an InitTensorOp: the static/dynamic size operands must be
/// consistent with the rank of the result type, and the result type must
/// match the type inferred from the static sizes.
static LogicalResult verify(InitTensorOp op) {
  RankedTensorType resultType = op.getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      op.static_sizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  // Each dynamic entry in `static_sizes` must have a matching size operand.
  if (failed(verifyListOfOperandsOrIntegers(op, "sizes", resultType.getRank(),
                                            op.static_sizes(), op.sizes(),
                                            ShapedType::isDynamic)))
    return failure();

  // There must be exactly one size entry per result dimension.
  if (op.static_sizes().size() != static_cast<unsigned>(resultType.getRank()))
    return op->emitError("expected ")
           << resultType.getRank() << " sizes values";

  // The declared result type must agree with the one implied by the sizes.
  Type expectedType = InitTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return op.emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}
942 
/// Infers the result tensor type of an InitTensorOp from its static sizes,
/// element type, and optional encoding attribute.
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                   Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}
947 
948 namespace {
949 /// Change the type of the result of a `linalg.init_tensor` by making the result
950 /// type statically sized along dimension that in the original operation where
951 /// defined as dynamic, but the size was defined using a `constant` op. For
952 /// example
953 ///
954 /// %c5 = arith.constant 5: index
955 /// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
956 ///
957 /// to
958 ///
959 /// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
960 struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
962 
963  LogicalResult matchAndRewrite(InitTensorOp op,
964  PatternRewriter &rewriter) const override {
965  SmallVector<Value, 4> dynamicSizes;
966  SmallVector<int64_t, 4> staticSizes;
967  for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
968  // If the size is already static, nothing to do.
969  if (!op.isDynamicSize(i)) {
970  staticSizes.push_back(op.getStaticSize(i));
971  continue;
972  }
973 
974  // If the size is dynamic but defined using a `constant` op, get the
975  // constant value to find the static size to use.
976  unsigned operandNum = op.getIndexOfDynamicSize(i);
977  Value sizeOperand = op.getOperand(operandNum);
978  if (auto constantIndexOp =
979  sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
980  staticSizes.push_back(constantIndexOp.value());
981  continue;
982  }
983 
984  // Fallback case. Keep the size dynamic.
985  dynamicSizes.push_back(sizeOperand);
986  staticSizes.push_back(ShapedType::kDynamicSize);
987  }
988  RankedTensorType newType =
989  RankedTensorType::get(staticSizes, op.getType().getElementType());
990  if (newType == op.getType())
991  return failure();
992  auto newOp =
993  rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
994  rewriter.getI64ArrayAttr(staticSizes));
995  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
996  return success();
997  }
998 };
999 } // namespace
1000 
1001 namespace {
1002 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
1003 /// slice of this is also needed only for its shape. The result can be
1004 /// replaced by a new init_tensor operation of the same size as the extract
1005 /// slice op.
1006 struct FoldInitTensorWithExtractSliceOp
1007  : public OpRewritePattern<tensor::ExtractSliceOp> {
1009 
1010  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
1011  PatternRewriter &rewriter) const override {
1012  if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
1013  return failure();
1014  // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
1015  // as well as its result type.
1016  rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
1017  sliceOp, sliceOp.sizes(),
1018  sliceOp.result().getType().cast<RankedTensorType>().getShape(),
1019  sliceOp.getSourceType().getElementType());
1020  return success();
1021  }
1022 };
1023 
1024 template <typename TensorReshapeOp>
1025 struct FoldInitTensorWithTensorReshapeOp
1026  : public OpRewritePattern<TensorReshapeOp> {
1028 
1029  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1030  PatternRewriter &rewriter) const override {
1031  if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
1032  return failure();
1033  Location loc = reshapeOp.getLoc();
1034  ReifiedRankedShapedTypeDims resultShapes;
1035  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
1036  cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
1037  if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
1038  resultShapes)) ||
1039  !llvm::hasSingleElement(resultShapes))
1040  return failure();
1041  Value initTensor = rewriter.create<InitTensorOp>(
1042  loc, getAsOpFoldResult(resultShapes[0]),
1043  reshapeOp.getResultType().getElementType());
1044  if (initTensor.getType() != reshapeOp.getResultType()) {
1045  rewriter.replaceOpWithNewOp<tensor::CastOp>(
1046  reshapeOp, reshapeOp.getResultType(), initTensor);
1047  } else {
1048  rewriter.replaceOp(reshapeOp, initTensor);
1049  }
1050  return success();
1051  }
1052 };
1053 
1054 struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
1056 
1057  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1058  PatternRewriter &rewriter) const override {
1059  Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1060  auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
1061  if (!initTensorOp || !maybeConstantIndex)
1062  return failure();
1063  if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
1064  return failure();
1065  rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
1066  return success();
1067  }
1068 };
1069 } // namespace
1070 
/// Registers the canonicalization patterns for linalg.init_tensor.
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
}
1078 
1079 LogicalResult InitTensorOp::reifyResultShapes(
1080  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1081  auto shapes = llvm::to_vector<4>(llvm::map_range(
1082  llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
1083  if (isDynamicSize(dim))
1084  return getDynamicSize(dim);
1085  return builder.create<arith::ConstantIndexOp>(getLoc(),
1086  getStaticSize(dim));
1087  }));
1088  reifiedReturnShapes.emplace_back(std::move(shapes));
1089  return success();
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // YieldOp
1094 //===----------------------------------------------------------------------===//
1095 
1096 static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1097  if (op.getNumOperands() > 0)
1098  p << ' ' << op.getOperands();
1099  p.printOptionalAttrDict(op->getAttrs());
1100  if (op.getNumOperands() > 0)
1101  p << " : " << op.getOperandTypes();
1102 }
1103 
1106  SmallVector<Type, 2> types;
1107  llvm::SMLoc loc = parser.getCurrentLocation();
1108  return failure(parser.parseOperandList(opInfo) ||
1109  parser.parseOptionalAttrDict(result.attributes) ||
1110  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1111  parser.resolveOperands(opInfo, types, loc, result.operands));
1112 }
1113 
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // One yielded value is required per output operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumOutputs())
    return op.emitOpError("expected number of yield values (")
           << linalgOp.getNumOutputs()
           << ") to match the number of operands of the enclosing "
           << "LinalgOp (" << op.getNumOperands() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getOutputOperand(opOperand.getOperandNumber());
    // Yielded values are scalars; compare against the element type of the
    // corresponding shaped output operand.
    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
1136 
1137 static LogicalResult verify(linalg::YieldOp op) {
1138  auto *parentOp = op->getParentOp();
1139  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1140  return op.emitOpError("expected single non-empty parent region");
1141 
1142  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1143  return verifyYield(op, cast<LinalgOp>(parentOp));
1144 
1145  if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
1146  // Check if output args with tensor types match results types.
1147  SmallVector<Value, 2> tensorOuts;
1148  llvm::copy_if(
1149  tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
1150  [&](Value out) { return out.getType().isa<RankedTensorType>(); });
1151  if (tensorOuts.size() != op.values().size())
1152  return op.emitOpError("expected number of tensor output args = ")
1153  << tensorOuts.size() << " to match the number of yield operands = "
1154  << op.values().size();
1155 
1156  TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
1157  for (auto &item :
1158  llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
1159  Type outType, resultType;
1160  unsigned index = item.index();
1161  std::tie(outType, resultType) = item.value();
1162  if (outType != resultType)
1163  return op.emitOpError("expected yield operand ")
1164  << index << " with type = " << resultType
1165  << " to match output arg type = " << outType;
1166  }
1167  return success();
1168  }
1169  return op.emitOpError("expected parent op with LinalgOp interface");
1170 }
1171 
1172 //===----------------------------------------------------------------------===//
1173 // TiledLoopOp
1174 //===----------------------------------------------------------------------===//
1175 
1176 void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
1177  ValueRange lowerBounds, ValueRange upperBounds,
1178  ValueRange steps, ValueRange inputs, ValueRange outputs,
1179  ArrayAttr iteratorTypes,
1182  bodyBuilderFn) {
1183  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
1184  iteratorTypes, llvm::None, bodyBuilderFn);
1185 }
1186 
1187 void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
1188  ValueRange lowerBounds, ValueRange upperBounds,
1189  ValueRange steps, ValueRange inputs, ValueRange outputs,
1190  ArrayAttr iteratorTypes,
1191  Optional<ArrayAttr> distributionTypes,
1194  bodyBuilderFn) {
1195  result.addOperands(lowerBounds);
1196  result.addOperands(upperBounds);
1197  result.addOperands(steps);
1198  result.addOperands(inputs);
1199  result.addOperands(outputs);
1200  result.addAttribute(
1201  TiledLoopOp::getOperandSegmentSizeAttr(),
1202  builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
1203  static_cast<int32_t>(upperBounds.size()),
1204  static_cast<int32_t>(steps.size()),
1205  static_cast<int32_t>(inputs.size()),
1206  static_cast<int32_t>(outputs.size())}));
1207  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
1208 
1209  if (distributionTypes.hasValue())
1211  distributionTypes.getValue());
1212 
1213  // Add output types for `RankedTensorType` output arguments.
1214  for (Value output : outputs) {
1215  Type outputType = output.getType();
1216  if (outputType.isa<RankedTensorType>())
1217  result.addTypes(outputType);
1218  }
1219 
1220  OpBuilder::InsertionGuard guard(builder);
1221  unsigned numIVs = steps.size();
1222  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
1223  SmallVector<Location, 8> argLocs(numIVs, result.location);
1224  for (Value input : inputs) {
1225  argTypes.push_back(input.getType());
1226  argLocs.push_back(input.getLoc());
1227  }
1228  for (Value output : outputs) {
1229  argTypes.push_back(output.getType());
1230  argLocs.push_back(output.getLoc());
1231  }
1232  Region *bodyRegion = result.addRegion();
1233  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1234 
1235  if (bodyBuilderFn) {
1236  builder.setInsertionPointToStart(bodyBlock);
1237  bodyBuilderFn(builder, result.location,
1238  bodyBlock->getArguments().take_front(numIVs),
1239  bodyBlock->getArguments().slice(numIVs, inputs.size()),
1240  bodyBlock->getArguments().take_back(outputs.size()));
1241  TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
1242  }
1243 }
1244 
1245 static void print(OpAsmPrinter &p, TiledLoopOp op) {
1246  p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to ("
1247  << op.upperBound() << ") step (" << op.step() << ")";
1248 
1249  if (!op.inputs().empty()) {
1250  p << " ins (";
1251  llvm::interleaveComma(llvm::zip(op.getRegionInputArgs(), op.inputs()), p,
1252  [&](auto it) {
1253  p << std::get<0>(it) << " = " << std::get<1>(it)
1254  << ": " << std::get<1>(it).getType();
1255  });
1256  p << ")";
1257  }
1258  if (!op.outputs().empty()) {
1259  p << " outs (";
1260  llvm::interleaveComma(llvm::zip(op.getRegionOutputArgs(), op.outputs()), p,
1261  [&](auto it) {
1262  p << std::get<0>(it) << " = " << std::get<1>(it)
1263  << ": " << std::get<1>(it).getType();
1264  });
1265  p << ")";
1266  }
1267 
1268  if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
1269  return attr.cast<StringAttr>().getValue() !=
1271  }))
1272  p << " iterators" << op.iterator_types();
1273 
1274  if (op.distribution_types().hasValue())
1275  p << " distribution" << op.distribution_types().getValue();
1276 
1277  p << ' ';
1278  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1280  op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
1283 }
1284 
1286  OperationState &result) {
1287  auto &builder = parser.getBuilder();
1288  // Parse an opening `(` followed by induction variables followed by `)`
1290  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1292  return failure();
1293 
1294  // Parse loop bounds.
1296  if (parser.parseEqual() ||
1297  parser.parseOperandList(lower, ivs.size(),
1299  parser.resolveOperands(lower, builder.getIndexType(), result.operands))
1300  return failure();
1301 
1303  if (parser.parseKeyword("to") ||
1304  parser.parseOperandList(upper, ivs.size(),
1306  parser.resolveOperands(upper, builder.getIndexType(), result.operands))
1307  return failure();
1308 
1309  // Parse step values.
1311  if (parser.parseKeyword("step") ||
1312  parser.parseOperandList(steps, ivs.size(),
1314  parser.resolveOperands(steps, builder.getIndexType(), result.operands))
1315  return failure();
1316 
1317  // Parse input tensors.
1318  SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
1319  SmallVector<Type, 4> inputTypes;
1320  if (succeeded(parser.parseOptionalKeyword("ins"))) {
1321  llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
1322 
1323  if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
1324  inputTypes))
1325  return failure();
1326 
1327  if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
1328  result.operands))
1329  return failure();
1330  }
1331 
1332  // Parse output tensors.
1333  SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
1334  SmallVector<Type, 4> outputTypes;
1335  if (succeeded(parser.parseOptionalKeyword("outs"))) {
1336  llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
1337 
1338  if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
1339  outputTypes))
1340  return failure();
1341 
1342  if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
1343  result.operands))
1344  return failure();
1345  for (Type outputType : outputTypes)
1346  if (outputType.isa<RankedTensorType>())
1347  result.addTypes(outputType);
1348  }
1349 
1350  // Parse attributes.
1351  SmallVector<Attribute, 4> iterTypes, distributionTypes;
1352  auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
1353  if (succeeded(parser.parseOptionalKeyword(keyword))) {
1354  StringAttr attr;
1355 
1356  if (parser.parseLSquare() || parser.parseAttribute(attr))
1357  return failure();
1358  attrs->push_back(attr);
1359  for (int i = 1, e = ivs.size(); i < e; ++i) {
1360  if (parser.parseComma() || parser.parseAttribute(attr))
1361  return failure();
1362  attrs->push_back(attr);
1363  }
1364  if (parser.parseRSquare())
1365  return failure();
1366  }
1367  return success();
1368  };
1369  if (failed(parseAttr("iterators", &iterTypes)) ||
1370  failed(parseAttr("distribution", &distributionTypes)))
1371  return failure();
1372 
1373  // Set all loop iterator types to "parallel" if they are not printed in IR.
1374  if (iterTypes.empty()) {
1375  auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
1376  iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
1377  }
1379  builder.getArrayAttr(iterTypes));
1380  if (!distributionTypes.empty())
1382  builder.getArrayAttr(distributionTypes));
1383  result.addAttribute(
1384  TiledLoopOp::getOperandSegmentSizeAttr(),
1385  builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
1386  static_cast<int32_t>(upper.size()),
1387  static_cast<int32_t>(steps.size()),
1388  static_cast<int32_t>(inputs.size()),
1389  static_cast<int32_t>(outputs.size())}));
1390 
1391  // Parse the body.
1392  Region *body = result.addRegion();
1393 
1394  SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
1395  regionTypes.append(inputTypes);
1396  regionTypes.append(outputTypes);
1397 
1399  regionArgs.append(inputRegionArgs);
1400  regionArgs.append(outputRegionArgs);
1401 
1402  if (parser.parseRegion(*body, regionArgs, regionTypes))
1403  return failure();
1404 
1405  // Parse optional attributes.
1406  parser.parseOptionalAttrDict(result.attributes);
1407 
1408  return success();
1409 }
1410 
/// Returns the loop body region (LoopLikeOpInterface).
Region &TiledLoopOp::getLoopBody() { return region(); }
1412 
/// Hoists `ops` to just before this loop (LoopLikeOpInterface). Callers are
/// responsible for ensuring the ops are safe to move; no checking is done.
LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (auto *op : ops)
    op->moveBefore(*this);
  return success();
}
1418 
/// Returns true if `value` is defined outside the loop's body region
/// (LoopLikeOpInterface).
bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
  return !region().isAncestor(value.getParentRegion());
}
1422 
/// Verifies a TiledLoopOp: one iterator type per loop, and every input/output
/// operand type must match its corresponding region argument type.
static LogicalResult verify(TiledLoopOp op) {
  // Check if iterator types are provided for every loop dimension.
  if (op.iterator_types().size() != op.getNumLoops())
    return op.emitOpError("expected iterator types array attribute size = ")
           << op.iterator_types().size()
           << " to match the number of loops = " << op.getNumLoops();

  // Check if types of input arguments match region args types.
  for (auto &item :
       llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
    Value input, inputRegionArg;
    unsigned index = item.index();
    std::tie(input, inputRegionArg) = item.value();
    if (input.getType() != inputRegionArg.getType())
      return op.emitOpError("expected input arg ")
             << index << " with type = " << input.getType()
             << " to match region arg " << index + op.getNumLoops()
             << " type = " << inputRegionArg.getType();
  }

  // Check if types of output arguments match region args types.
  for (auto &item :
       llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
    Value output, outputRegionArg;
    unsigned index = item.index();
    std::tie(output, outputRegionArg) = item.value();
    if (output.getType() != outputRegionArg.getType())
      return op.emitOpError("expected output arg ")
             << index << " with type = " << output.getType()
             << " to match region arg "
             << index + op.getNumLoops() + op.inputs().size()
             << " type = " << outputRegionArg.getType();
  }
  return success();
}
1458 
1459 namespace {
1460 
1461 static constexpr int64_t kNoMatch = -1;
1462 
1463 // Folds away TiledLoopOp inputs if they have no uses within the body.
1464 //
1465 // Example:
1466 //
1467 // %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>,
1468 // %in_buf_ = %in_buf: memref<...>) {...}
1469 // Becomes
1470 //
1471 // linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...}
1472 struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
1474 
1475  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
1476  PatternRewriter &rewriter) const final {
1477  SmallVector<Value, 2> newInputs, regionInputTensorArgs;
1478  // Store ids of the corresponding old and new input operands.
1479  SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
1480  kNoMatch);
1481  for (const auto &en : llvm::enumerate(
1482  llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
1483  Value in, bbArg;
1484  size_t index = en.index();
1485  std::tie(in, bbArg) = en.value();
1486  if (!bbArg.use_empty()) {
1487  oldInputIdToNew[index] = newInputs.size();
1488  newInputs.push_back(in);
1489  }
1490  }
1491  if (newInputs.size() == tiledLoop.inputs().size())
1492  return failure();
1493  Location loc = tiledLoop.getLoc();
1494  auto newTiledLoop = rewriter.create<TiledLoopOp>(
1495  loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
1496  newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
1497  tiledLoop.distribution_types());
1498 
1499  // Clone the region.
1501  bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
1502  bvm.map(tiledLoop.getRegionOutputArgs(),
1503  newTiledLoop.getRegionOutputArgs());
1504  for (const auto &en : llvm::enumerate(oldInputIdToNew))
1505  if (en.value() != kNoMatch)
1506  bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
1507  newTiledLoop.getRegionInputArgs()[en.value()]);
1508  OpBuilder innerBuilder =
1509  OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
1510  for (auto &op : *tiledLoop.getBody())
1511  innerBuilder.clone(op, bvm);
1512  rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
1513 
1514  return success();
1515  }
1516 };
1517 
1518 } // namespace
1519 
1520 /// A simple, conservative analysis to determine if the loop is shape
1521 /// conserving. I.e., the type of the arg-th yielded value is the same as the
1522 /// type of the corresponding basic block argument of the loop.
1523 /// Note: This function handles only simple cases. Expand as needed.
1524 static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
1525  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
1526  if (yieldOp.values().empty())
1527  // Tiled loop either has no outputs or is a "memref-based version". In
1528  // either case, the loop is shape conserving.
1529  return true;
1530  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
1531  "arg is out of bounds");
1532  Value value = yieldOp.values()[arg];
1533  while (value) {
1534  if (value == loopOp.getRegionOutputArgs()[arg])
1535  return true;
1536  OpResult opResult = value.dyn_cast<OpResult>();
1537  if (!opResult)
1538  return false;
1539 
1540  using tensor::InsertSliceOp;
1542  .template Case<InsertSliceOp>(
1543  [&](InsertSliceOp op) { return op.dest(); })
1544  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
1545  return isShapePreserving(loopOp, opResult.getResultNumber())
1546  ? loopOp.outputs()[opResult.getResultNumber()]
1547  : Value();
1548  })
1549  .Default([&](auto op) { return Value(); });
1550  }
1551  return false;
1552 }
1553 
1554 namespace {
1555 
1556 /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
1557 /// to dim(y) where `y` is the initial input/output value of the argument.
1558 ///
1559 /// E.g.:
1560 /// %y = ... : tensor<...>
1561 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
1562 /// tensor.dim %x, %c0 : tensor<...>
1563 /// }
1564 ///
1565 /// is folded to:
1566 /// %y = ... : tensor<...>
1567 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
1568 /// tensor.dim %y, %c0 : tensor<...>
1569 /// }
1570 ///
1571 /// Note: Dim ops are folded only if it can be proven that the runtime type of
1572 /// the yielded value (in case of outputs) does not change with loop iterations.
1573 template <typename OpTy>
1574 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
1576 
1577  LogicalResult matchAndRewrite(OpTy dimOp,
1578  PatternRewriter &rewriter) const final {
1579  auto src = dimOp.source().template dyn_cast<BlockArgument>();
1580  if (!src)
1581  return failure();
1582  auto loopOp =
1583  dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
1584  if (!loopOp)
1585  return failure();
1586  unsigned numLoops = loopOp.getNumLoops();
1587  unsigned numInputArgs = loopOp.getRegionInputArgs().size();
1588  if (src.getArgNumber() >= numInputArgs + numLoops &&
1589  !isShapePreserving(loopOp,
1590  src.getArgNumber() - numInputArgs - numLoops))
1591  return failure();
1592 
1593  auto inputArgs = loopOp.getRegionInputArgs();
1594  auto it1 = llvm::find(inputArgs, src);
1595  if (it1 != inputArgs.end()) {
1596  rewriter.updateRootInPlace(dimOp, [&] {
1597  dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
1598  });
1599  return success();
1600  }
1601 
1602  auto outputArgs = loopOp.getRegionOutputArgs();
1603  auto it2 = llvm::find(outputArgs, src);
1604  if (it2 != outputArgs.end()) {
1605  rewriter.updateRootInPlace(dimOp, [&] {
1606  dimOp.sourceMutable().assign(
1607  loopOp.outputs()[it2 - outputArgs.begin()]);
1608  });
1609  return success();
1610  }
1611 
1612  return failure();
1613  }
1614 };
1615 
1616 /// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y`
1617 /// is the initial output value of the loop.
1618 ///
1619 /// E.g.:
1620 /// %y = ... : tensor<...>
1621 /// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
1622 /// ...
1623 /// }
1624 /// %0 = tensor.dim %r, %c0 : tensor<...>
1625 ///
1626 /// is folded to:
1627 /// %y = ... : tensor<...>
1628 /// linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
1629 /// ...
1630 /// }
1631 /// %0 = tensor.dim %y, %c0 : tensor<...>
1632 ///
1633 /// Note: Dim ops are folded only if it can be proven that the runtime type of
1634 /// the yielded value (in case of outputs) does not change with loop iterations.
1635 template <typename OpTy>
1636 struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
1638 
1639  LogicalResult matchAndRewrite(OpTy dimOp,
1640  PatternRewriter &rewriter) const final {
1641  auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>();
1642  if (!loopOp)
1643  return failure();
1644  auto opResult = dimOp.source().template cast<OpResult>();
1645  unsigned resultNumber = opResult.getResultNumber();
1646  if (!isShapePreserving(loopOp, resultNumber))
1647  return failure();
1648  rewriter.updateRootInPlace(dimOp, [&]() {
1649  dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]);
1650  });
1651  return success();
1652  }
1653 };
1654 
1655 // Folds away TiledLoopOp output tensors when the following conditions are met:
1656 // * result of `linalg.tiled_loop` has no uses
1657 // * output tensor is the argument of `linalg.yield`
1658 //
1659 // Example:
1660 //
1661 // %0 = linalg.tiled_loop ... outs (%o_ = %out: tensor<...>,
1662 // %obuf_ = %out_buf: memref<...>) {
1663 // ...
1664 // linalg.yield %o_ : tensor ...
1665 // }
1666 //
1667 // Becomes
1668 //
1669 // linalg.tiled_loop ... outs (%obuf_ = %out_buf: memref<...>) {
1670 // ...
1671 // linalg.yield
1672 // }
1673 struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
1675 
1676  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
1677  PatternRewriter &rewriter) const final {
1678  if (tiledLoop.getNumResults() == 0)
1679  return failure();
1680 
1681  Block *block = tiledLoop.getBody();
1682  auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
1683 
1684  // Match the pattern and collect output buffers that will replace the output
1685  // tensors and also the ops that will be ignored when cloning the body.
1686  SmallVector<Value, 2> newOutputOperands, newYieldArgs;
1687  int resultId = 0;
1688  // Store ids of the corresponding old and new output operands.
1689  SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
1690  kNoMatch);
1691  // Store ids of the corresponding old and new results.
1692  SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
1693  kNoMatch);
1694  SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
1695  for (const auto &en : llvm::enumerate(
1696  llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
1697  size_t index = en.index();
1698  Value out = std::get<0>(en.value());
1699  Value outRegionArg = std::get<1>(en.value());
1700 
1701  if (!out.getType().isa<RankedTensorType>()) {
1702  oldOutputIdToNew[index] = newOutputOperands.size();
1703  newOutputOperands.push_back(out);
1704  continue;
1705  }
1706  Value result = tiledLoop.getResult(resultId);
1707  Value yieldArg = yieldOp.getOperand(resultId);
1708  if (yieldArg != outRegionArg || !result.use_empty()) {
1709  oldOutputIdToNew[index] = newOutputOperands.size();
1710  oldResultIdToNew[resultId] = newYieldArgs.size();
1711  resultReplacement[resultId] = out;
1712  newOutputOperands.push_back(out);
1713  newYieldArgs.push_back(yieldArg);
1714  }
1715  ++resultId;
1716  }
1717  if (newOutputOperands.size() == tiledLoop.outputs().size())
1718  return failure();
1719 
1720  Location loc = tiledLoop.getLoc();
1721  auto newTiledLoop = rewriter.create<TiledLoopOp>(
1722  loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
1723  tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(),
1724  tiledLoop.distribution_types());
1725 
1726  // Clone the region.
1728  bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
1729  bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
1730  for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
1731  if (en.value() != kNoMatch)
1732  bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
1733  newTiledLoop.getRegionOutputArgs()[en.value()]);
1734  else
1735  bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
1736  tiledLoop.outputs()[en.index()]);
1737  }
1738  OpBuilder innerBuilder =
1739  OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
1740  for (auto &op : tiledLoop.getBody()->without_terminator())
1741  innerBuilder.clone(op, bvm);
1742  innerBuilder.create<linalg::YieldOp>(
1743  loc, llvm::to_vector<2>(llvm::map_range(
1744  newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
1745 
1746  for (const auto &en : llvm::enumerate(oldResultIdToNew))
1747  if (en.value() != kNoMatch)
1748  resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
1749  rewriter.replaceOp(tiledLoop, resultReplacement);
1750 
1751  return success();
1752  }
1753 };
1754 } // namespace
1755 
1756 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1757  MLIRContext *context) {
1758  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
1759  DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
1760  DimOfTiledLoopInsOutsFolder<memref::DimOp>,
1761  DimOfTiledLoopResultFolder<tensor::DimOp>,
1762  DimOfTiledLoopResultFolder<memref::DimOp>>(context);
1763 }
1764 
1765 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
1766  SmallVectorImpl<OpFoldResult> &) {
1767  return foldMemRefCastInTiledLoopOp(*this);
1768 }
1769 
1770 //===----------------------------------------------------------------------===//
1771 // IndexOp
1772 //===----------------------------------------------------------------------===//
1773 
1774 static LogicalResult verify(IndexOp op) {
1775  auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
1776  if (!linalgOp)
1777  return op.emitOpError("expected parent op with LinalgOp interface");
1778  if (linalgOp.getNumLoops() <= op.dim())
1779  return op.emitOpError("expected dim (")
1780  << op.dim() << ") to be lower than the number of loops ("
1781  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1782  return success();
1783 }
1784 
1785 /////// Operations corresponding to library calls defined with Tablegen ////////
1786 
1787 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1788 
1789 #define GET_OP_CLASSES
1790 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1791 
1792 #define GET_OP_CLASSES
1793 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1794 
1795 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1796 /// Assumes `op` is a LinalgOp.
1797 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1799  if (!cast<LinalgOp>(op).iterator_types())
1800  return;
1801 
1802  unsigned dim = 0;
1803  for (auto tn :
1804  cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1805  if (tn == iteratorTypeName)
1806  res.push_back(dim);
1807  ++dim;
1808  }
1809 }
1810 
1812  unsigned rank,
1813  MLIRContext *context) {
1814  if (maybeMap)
1815  return maybeMap.getValue();
1816  if (rank == 0)
1817  return AffineMap::get(context);
1818  return AffineMap::getMultiDimIdentityMap(rank, context);
1819 }
1820 
1822 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1823  MLIRContext *context) {
1825  res.reserve(num);
1826  for (unsigned i = 0; i < num; ++i)
1827  res.push_back(getAffineDimExpr(startIdx++, context));
1828  return res;
1829 }
1830 
1833  auto rangeA = llvm::make_range(a.begin(), a.end());
1834  auto rangeB = llvm::make_range(b.begin(), b.end());
1835  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1836  return llvm::to_vector<4>(concatRanges);
1837 }
1838 
1839 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1840  if (auto memref = t.dyn_cast<MemRefType>()) {
1841  ss << "view";
1842  for (auto size : memref.getShape())
1843  if (size < 0)
1844  ss << "sx";
1845  else
1846  ss << size << "x";
1847  appendMangledType(ss, memref.getElementType());
1848  } else if (auto vec = t.dyn_cast<VectorType>()) {
1849  ss << "vector";
1850  llvm::interleave(
1851  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1852  appendMangledType(ss, vec.getElementType());
1853  } else if (t.isSignlessIntOrIndexOrFloat()) {
1854  ss << t;
1855  } else {
1856  llvm_unreachable("Invalid type for linalg library name mangling");
1857  }
1858 }
1859 
1861  assert(isa<LinalgOp>(op));
1862  std::string name(op->getName().getStringRef().str());
1863  name.reserve(128);
1864  std::replace(name.begin(), name.end(), '.', '_');
1865  llvm::raw_string_ostream ss(name);
1866  ss << "_";
1867  auto types = op->getOperandTypes();
1868  llvm::interleave(
1869  types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1870  [&]() { ss << "_"; });
1871  return ss.str();
1872 }
1873 
1874 //===----------------------------------------------------------------------===//
1875 // Support for named Linalg ops defined in ods-gen.
1876 //===----------------------------------------------------------------------===//
1877 
1878 /// Generic entry point to create the block for the region of a LinalgOp.
1879 /// This is used by both named structured ops created by ods-gen and by manually
1880 /// defined C++ ops.
1881 /// This is used by both builders and parsers.
1882 /// This function creates the block in the region with arguments corresponding
1883 /// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
1884 /// to be ShapedType.
1885 template <typename NamedStructuredOpType>
1887  OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
1888  TypeRange outputTypes,
1889  llvm::function_ref<void(unsigned, unsigned)> errorHandler) {
1890  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
1891 
1892  // TODO: atm all operands go through getElementTypeOrSelf,
1893  // reconsider when we have evidence we need to.
1894  SmallVector<Type, 8> argTypes;
1895  SmallVector<Location, 8> argLocs;
1896  for (auto containers : {inputTypes, outputTypes}) {
1897  for (auto t : containers) {
1898  argTypes.push_back(getElementTypeOrSelf(t));
1899 
1900  // TODO: Pass in a proper location here.
1901  argLocs.push_back(opBuilder.getUnknownLoc());
1902  }
1903  }
1904 
1905  // RAII.
1906  OpBuilder::InsertionGuard guard(opBuilder);
1907  Block *body =
1908  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
1909  unsigned actual = body->getNumArguments();
1910  unsigned expected = NamedStructuredOpType::getNumRegionArgs();
1911  if (expected != actual) {
1912  if (errorHandler)
1913  errorHandler(expected, actual);
1914  return;
1915  }
1916 
1917  opBuilder.setInsertionPointToStart(body);
1918  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
1919  NamedStructuredOpType::regionBuilder(b, *body);
1920 
1921  // indexing_maps is an auto-generated method.
1922 
1923  // iterator_types is an auto-generated method.
1924 }
1925 
1926 /// Generic entry point to create both the region and the block of a LinalgOp.
1927 template <typename NamedStructuredOpType>
1929  OperationState &result,
1930  TypeRange inputTypes,
1931  TypeRange outputTypes) {
1932  Region &region = *result.addRegion();
1933  fillStructuredOpRegion<NamedStructuredOpType>(
1934  opBuilder, region, inputTypes, outputTypes,
1935  [&](unsigned expected, unsigned actual) {
1936  assert(expected != actual && "incorrect number of arguments");
1937  });
1938 }
1939 
1940 /// Common parsing used for both named structured ops created by ods-gen and by
1941 /// manually defined C++ ops. Does not handle regions.
1942 static ParseResult
1944  SmallVectorImpl<Type> &inputTypes,
1945  SmallVectorImpl<Type> &outputTypes) {
1946  llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
1947  SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
1948 
1949  parser.parseOptionalAttrDict(result.attributes);
1950 
1951  if (succeeded(parser.parseOptionalKeyword("ins"))) {
1952  if (parser.parseLParen())
1953  return failure();
1954 
1955  inputsOperandsLoc = parser.getCurrentLocation();
1956  if (parser.parseOperandList(inputsOperands) ||
1957  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
1958  return failure();
1959  }
1960 
1961  if (succeeded(parser.parseOptionalKeyword("outs"))) {
1962  outputsOperandsLoc = parser.getCurrentLocation();
1963  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
1964  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
1965  return failure();
1966  }
1967 
1968  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
1969  result.operands) ||
1970  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
1971  result.operands))
1972  return failure();
1973 
1974  result.addAttribute("operand_segment_sizes",
1975  parser.getBuilder().getI32VectorAttr(
1976  {static_cast<int32_t>(inputsOperands.size()),
1977  static_cast<int32_t>(outputsOperands.size())}));
1978  return success();
1979 }
1980 
1981 template <typename NamedStructuredOpType>
1983  NamedStructuredOpType op) {
1984  if (!op.inputs().empty())
1985  p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
1986  if (!op.outputs().empty())
1987  p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
1988 }
1989 
1990 //===----------------------------------------------------------------------===//
1991 // Specific parsing and printing for named structured ops created by ods-gen.
1992 //===----------------------------------------------------------------------===//
1993 
1994 template <typename NamedStructuredOpType>
1995 static ParseResult
1997  TypeRange inputTypes, TypeRange outputTypes) {
1998  ParseResult res = success();
1999  OpBuilder opBuilder(parser.getContext());
2000  // Resolve `captures` into `capturedValues` at parse time so we can build the
2001  // region with captures.
2002  SmallVector<Value> capturedValues;
2003  fillStructuredOpRegion<NamedStructuredOpType>(
2004  opBuilder, region, inputTypes, outputTypes,
2005  [&](unsigned expected, unsigned actual) {
2006  res = parser.emitError(
2007  parser.getCurrentLocation(),
2008  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
2009  "region expects {0} args, got {1}",
2010  expected, actual));
2011  region.front().dump();
2012  });
2013  return res;
2014 }
2015 
2016 static ParseResult
2018  SmallVectorImpl<Type> &resultTypes) {
2019  if (parser.parseOptionalArrowTypeList(resultTypes))
2020  return failure();
2021  return success();
2022 }
2023 
2024 template <typename NamedStructuredOpType>
2026  OperationState &result) {
2027  // TODO: Enable when ods-gen supports captures.
2028  SmallVector<Type, 1> inputTypes, outputTypes;
2029  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
2030  return failure();
2031 
2032  // TODO: consider merging results parsing into region parsing.
2033  // Need to wait for declarative assembly resolution to decide.
2034  SmallVector<Type, 1> outputTensorsTypes;
2035  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
2036  return failure();
2037  result.addTypes(outputTensorsTypes);
2038 
2039  std::unique_ptr<Region> region = std::make_unique<Region>();
2040  if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
2041  parser, *region, inputTypes, outputTypes))
2042  return failure();
2043  result.addRegion(std::move(region));
2044 
2045  return success();
2046 }
2047 
2049  TypeRange resultTypes) {
2050  if (resultTypes.empty())
2051  return;
2052  p.printOptionalArrowTypeList(resultTypes);
2053 }
2054 
2055 template <typename NamedStructuredOpType>
2056 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
2058  op->getAttrs(),
2059  /*elidedAttrs=*/{"operand_segment_sizes",
2060  // See generated code in mlir-linalg-yaml-gen.cpp
2061  "linalg.memoized_indexing_maps"});
2062 
2063  // Printing is shared with generic ops, except for the region and
2064  // attributes.
2066 
2067  // Results printing.
2068  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
2069 
2070  // Region is elided.
2071 }
2072 
2073 template <typename NamedStructuredOpType>
2074 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
2075  return verifyGenericOp<NamedStructuredOpType>(op);
2076 }
2077 
2078 //===----------------------------------------------------------------------===//
2079 // Canonicalizers and Folders.
2080 //===----------------------------------------------------------------------===//
2081 
2082 namespace {
2083 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2085 
2086  LogicalResult matchAndRewrite(LinalgOp op,
2087  PatternRewriter &rewriter) const override {
2088  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
2089  // Linalg "inputs" may be either tensor or memref type.
2090  // tensor<0xelt_type> is a convention that may not always mean
2091  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2092  auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
2093  if (!mt)
2094  continue;
2095  if (llvm::is_contained(op.getShape(opOperand), 0)) {
2096  rewriter.eraseOp(op);
2097  return success();
2098  }
2099  }
2100  return failure();
2101  }
2102 };
2103 
2104 struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
2106 
2107  LogicalResult matchAndRewrite(LinalgOp op,
2108  PatternRewriter &rewriter) const override {
2109  // If no operand comes from a tensor::CastOp and can be folded then fail.
2110  bool hasTensorCastOperand =
2111  llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
2112  if (opOperand->get().isa<BlockArgument>())
2113  return false;
2114  auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2115  return castOp && canFoldIntoConsumerOp(castOp);
2116  });
2117  if (!hasTensorCastOperand)
2118  return failure();
2119 
2120  SmallVector<Type, 4> newResultTypes;
2121  newResultTypes.reserve(op->getNumResults());
2122  SmallVector<Value, 4> newOperands;
2123  newOperands.reserve(op->getNumOperands());
2124  // Inputs may fold.
2125  for (OpOperand *opOperand : op.getInputOperands()) {
2126  auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2127  newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
2128  ? tensorCastOp.source()
2129  : opOperand->get());
2130  }
2131  // Init tensors may fold, in which case the resultType must also change.
2132  for (OpOperand *opOperand : op.getOutputOperands()) {
2133  auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
2134  bool fold = canFoldIntoConsumerOp(tensorCastOp);
2135  newOperands.push_back(fold ? tensorCastOp.getOperand()
2136  : opOperand->get());
2137  newResultTypes.push_back(newOperands.back().getType());
2138  }
2139  // Clone op.
2140  Operation *newOp =
2141  op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
2142  SmallVector<Value, 4> replacements;
2143  replacements.reserve(newOp->getNumResults());
2144  for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
2145  Value oldResult = std::get<0>(result);
2146  Value newResult = std::get<1>(result);
2147  if (newResult.getType() != oldResult.getType()) {
2148  replacements.push_back(rewriter.create<tensor::CastOp>(
2149  op->getLoc(), oldResult.getType(), newResult));
2150  } else {
2151  replacements.push_back(newResult);
2152  }
2153  }
2154  rewriter.replaceOp(op, replacements);
2155 
2156  return success();
2157  }
2158 };
2159 
2160 } // namespace
2161 
2162 #define LINALGOP_FOLDERS(XXX) \
2163  LogicalResult XXX::fold(ArrayRef<Attribute>, \
2164  SmallVectorImpl<OpFoldResult> &) { \
2165  return foldMemRefCast(*this); \
2166  }
2167 
2168 LINALGOP_FOLDERS(CopyOp)
2169 LINALGOP_FOLDERS(FillOp)
2170 LINALGOP_FOLDERS(GenericOp)
2171 
2172 // All named ops canonicalizers and folders are auto-generated in the
2173 // .cpp.inc.
2174 
2175 //===----------------------------------------------------------------------===//
2176 // LinalgDialect
2177 //===----------------------------------------------------------------------===//
2178 
2179 void LinalgDialect::getCanonicalizationPatterns(
2180  RewritePatternSet &results) const {
2181  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
2182 }
2183 
2185  Attribute value, Type type,
2186  Location loc) {
2187  return builder.create<arith::ConstantOp>(loc, type, value);
2188 }
Location getUnknownLoc()
Definition: Builders.cpp:26
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
static void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, TypeRange outputTypes)
Generic entry point to create both the region and the block of a LinalgOp.
Definition: LinalgOps.cpp:1928
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
U cast() const
Definition: Attributes.h:123
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type)
FillOp region is elided when printing.
Definition: LinalgOps.cpp:505
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes)
Common parsing and printing used for both named structured ops created by ods-gen and by manually def...
Definition: LinalgOps.cpp:1943
This is a value defined by a result of an operation.
Definition: Value.h:423
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, TypeRange inputTypes, TypeRange outputTypes)
Specific parsing and printing for named structured ops created by ods-gen.
Definition: LinalgOps.cpp:1996
Block represents an ordered list of Operations.
Definition: Block.h:29
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec, int64_t sentinel)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
AffineMap extractOrIdentityMap(Optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank...
Definition: LinalgOps.cpp:1811
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:457
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:175
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op)
Definition: LinalgOps.cpp:1982
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:2048
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type)
CopyOp region is elided when printing.
Definition: LinalgOps.cpp:410
operand_type_range getOperandTypes()
Definition: Operation.h:266
static void print(OpAsmPrinter &p, GenericOp op)
Definition: LinalgOps.cpp:624
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op)
Definition: LinalgOps.cpp:2074
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:215
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, ValueRange values, llvm::function_ref< bool(int64_t)> isDynamic)
Verify that a the values has as many elements as the number of entries in attr for which isDynamic ev...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
constexpr StringRef getIteratorTypesAttrName()
Attribute name for the StrArrayAttr which encodes the type of a structured op&#39;s iterators.
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:2017
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:310
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
void replaceAllUsesWith(Value newValue) const
Replace all uses of &#39;this&#39; value with the new value, updating anything in the IR that uses &#39;this&#39; to ...
Definition: Value.h:161
virtual ParseResult parseComma()=0
Parse a , token.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
static constexpr const bool value
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:209
SmallVector< Value, 4 > operands
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:432
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:137
void map(Block *from, Block *to)
Inserts a new mapping for &#39;from&#39; to &#39;to&#39;.
void assign(const_iterator in_start, const_iterator in_end)
Replaces the attributes with new list of attributes.
static DefaultResource * get()
Returns a unique instance for the given effect class.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:258
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:252
virtual ParseResult parseLSquare()=0
Parse a [ token.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:77
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, llvm::function_ref< void(unsigned, unsigned)> errorHandler=nullptr)
Forward declarations.
Definition: LinalgOps.cpp:1886
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
constexpr StringRef getDistributionTypesAttrName()
Attribute name for the StrArrayAttr which encodes the distribution type for linalg.tiled_loop.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:1860
static LogicalResult verifyGenericOp(GenericOpType op)
Definition: LinalgOps.cpp:738
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
void addOperands(ValueRange newOperands)
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:136
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, ValueRange inputBuffers, ValueRange outputs)
Definition: LinalgOps.cpp:708
U dyn_cast() const
Definition: Types.h:244
static ParseResult parseTiledLoopOp(OpAsmParser &parser, OperationState &result)
Definition: LinalgOps.cpp:1285
unsigned getNumArguments()
Definition: Block.h:119
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result)
Definition: LinalgOps.cpp:1104
Attributes are known-constant values of operations.
Definition: Attributes.h:24
U dyn_cast() const
Definition: Value.h:99
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:435
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void eraseArgument(unsigned index)
Erase the argument at &#39;index&#39; and remove it from the argument list.
Definition: Block.cpp:181
bool isIndex() const
Definition: Types.cpp:28
virtual ParseResult parseRParen()=0
Parse a ) token.
static void appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:1839
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:133
This represents an operation in an abstracted form, suitable for use with the builder APIs...
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn&#39;t have a listener...
Definition: Builders.h:251
void getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl< unsigned > &res)
Return the dims that are iteratorTypeName loops in the LinalgOp op.
Definition: LinalgOps.cpp:1797
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:38
Parens surrounding zero or more operands.
BlockArgListType getArguments()
Definition: Block.h:76
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType, Type outputType)
Definition: LinalgOps.cpp:401
This class represents an argument of a Block.
Definition: Value.h:298
This class represents a specific instance of an effect.
virtual ParseResult parseRSquare()=0
Parse a ] token.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:491
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:91
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result)
Definition: LinalgOps.cpp:667
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:202
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
ParseResult parseAssignmentListWithTypes(SmallVectorImpl< OperandType > &lhs, SmallVectorImpl< OperandType > &rhs, SmallVectorImpl< Type > &types)
Parse a list of assignments of the form (x1 = y1 : type1, x2 = y2 : type2, ...)
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op)
Definition: LinalgOps.cpp:2056
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
Region * addRegion()
Create a region that should be attached to the operation.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1116
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op)
This is a specialization of foldMemRefCast used for patterns of the form tiled_loop(memrefcast(%src))...
Definition: LinalgOps.cpp:113
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
Definition: PatternMatch.h:930
IndexType getIndexType()
Definition: Builders.cpp:48
iterator end()
Definition: Region.h:56
ImplicitLocOpBuilder maintains a &#39;current location&#39;, allowing use of the create<> method without spec...
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result)
Definition: LinalgOps.cpp:2025
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
This class represents an operand of an operation.
Definition: Value.h:249
static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg)
A simple, conservative analysis to determine if the loop is shape conserving.
Definition: LinalgOps.cpp:1524
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, Type outputType)
Definition: LinalgOps.cpp:496
Operation * clone(BlockAndValueMapping &mapper)
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:564
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
static LogicalResult foldMemRefCast(Operation *op)
This is a common class used for patterns of the form someop(memrefcast(%src)) -> someop(%src) It fold...
Definition: LinalgOps.cpp:96
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into &#39;result&#39; if it is present.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:1822
#define LINALGOP_FOLDERS(XXX)
Definition: LinalgOps.cpp:2162
virtual ParseResult parseEqual()=0
Parse a = token.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with &#39;numDims&#39; identity result dim exprs.
Definition: AffineMap.cpp:244
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with &#39;argTypes&#39; arguments and set the insertion point to the end of it...
Definition: Builders.cpp:353
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
bool isa() const
Definition: Types.h:234
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:61
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
result_range getResults()
Definition: Operation.h:284
This class helps build Operations.
Definition: Builders.h:177
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
This class provides an abstraction over the different types of ranges over Values.
virtual ParseResult parseRegionArgumentList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more region arguments with a specified surrounding delimiter, and an optional required ...
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:370
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types &#39;Ts&#39; to the pattern list with the given arguments...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:1831
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:246