//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
31 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/PatternMatch.h"
36 
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/StringSet.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
44 #include <optional>
45 
46 using namespace mlir;
47 using namespace mlir::linalg;
48 
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(type)
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
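/// For example (illustrative IR only, names are hypothetical), a tensor-typed
/// `source` produces:
///   %slice = tensor.extract_slice %source[0][4][1] : tensor<8xf32> to tensor<4xf32>
/// whereas a memref-typed `source` produces an equivalent `memref.subview`.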
static Value getSlice(OpBuilder &b, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Value>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Value {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes,
                      [](Type t) { return llvm::isa<ShapedType>(t); }));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            [](Type type) { return llvm::isa<RankedTensorType>(type); });

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
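/// The parsed operand fragment has the following form (illustrative example;
/// the SSA names and types are hypothetical):
///   ins(%input : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>)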
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operandSegmentSizes",
                       // See generated code in
                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers build the unary, binary, and type conversion functions defined
// by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that
// uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }

  MLIRContext *context;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

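/// Erase self-copies, i.e. `linalg.copy` ops whose input and output operands
/// are the same value. Illustrative IR (hypothetical names):
///   linalg.copy ins(%buf : memref<4xf32>) outs(%buf : memref<4xf32>)
/// is simply erased; the tensor variant is replaced by its input.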
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
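///
/// Illustrative example (hypothetical IR):
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>) -> tensor<6x4xf32>
///   %r = tensor.collapse_shape %fill [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
/// is rewritten to reshape the init first and fill the reshaped tensor:
///   %newInit = tensor.collapse_shape %init [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
///   %r = linalg.fill ins(%cst : f32) outs(%newInit : tensor<24xf32>) -> tensor<24xf32>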
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
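///
/// Illustrative sketch (hypothetical IR): if %fill was produced by
///   linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
/// then padding %fill with the same constant %cst becomes a single
/// linalg.fill of %cst into a tensor.empty of the padded shape.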
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
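///
/// Illustrative sketch: inserting a padded tensor into a filled destination,
/// where the pad constant equals the fill constant, is equivalent to
/// inserting the unpadded source directly at offsets shifted by the low
/// padding; the surrounding fill already supplies the padded values.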
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this,
      // we cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
      newSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
              .getResult());
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
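///
/// Illustrative example (hypothetical IR):
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %elem = tensor.extract %fill[%idx] : tensor<4xf32>
/// folds %elem to %cst.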
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
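///
/// Illustrative example (hypothetical IR, tiling attributes elided):
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
///   %pack = tensor.pack %fill ... into %dest : tensor<8x8xf32> -> tensor<2x2x4x4xf32>
/// becomes a single fill of the pack destination:
///   %res = linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x4x4xf32>) -> tensor<2x2x4x4xf32>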
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
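///
/// Illustrative sketch: a linalg.copy whose input comes from a linalg.fill is
/// replaced by a fill of the copy's output; a copy whose output is produced
/// by a fill is replaced by a copy directly into the fill's init tensor.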
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must be checked into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, const ValueRange inputOperands,
    ValueRange outputOperands) {
  for (auto operand : inputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
  }
  for (auto operand : outputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), operand,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains only a yield operation, with the yielded values being
///    the block arguments corresponding to the operands.
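///
/// Illustrative example (hypothetical IR, #id an identity affine map):
///   %r = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       linalg.yield %in : f32
///   } -> tensor<?xf32>
/// is replaced by %arg0 (with a tensor.cast when the types differ).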
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.getRegion().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (genericOp.hasBufferSemantics()) {
      if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
          genericOp.getDpsInputOperand(0)->get() ==
              genericOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(genericOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!genericOp.hasTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = genericOp->getOperand(argumentNumber);
      Type resultType = genericOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              genericOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              genericOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != genericOp->getNumResults())
      return failure();
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityGenericOp>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands).drop_back());
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (besides the
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
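// For instance (illustrative), a map whose body is a single arith.addf on the
// block arguments prints in the short form:
//   linalg.map { arith.addf } ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//       outs(%init : tensor<4xf32>)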
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}

void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}

void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}

LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` matches the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}

SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}

void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}

void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}

SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}

void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1545 
1546 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1547  ArrayRef<int64_t> attributeValue) {
1548  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1549 }
1550 
1551 void ReduceOp::print(OpAsmPrinter &p) {
1552  Block *mapper = getBody();
1553  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1554  if (payloadOp) {
1555  printShortForm(p, payloadOp);
1556  }
1557 
1558  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1559  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1560  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1561  if (!payloadOp) {
1562  // Print region if the payload op was not detected.
1563  p.increaseIndent();
1564  p.printNewline();
1565  p << "(";
1566  llvm::interleaveComma(mapper->getArguments(), p,
1567  [&](auto arg) { p.printRegionArgument(arg); });
1568  p << ") ";
1569 
1570  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1571  p.decreaseIndent();
1572  }
1573 }
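// Illustrative output of the two forms printed above (types elided):
//
//   short form, payload op detected:
//     %reduced = linalg.reduce { arith.addf } ins(%in : ...)
//         outs(%out : ...) dimensions = [1]
//
//   long form, explicit region:
//     %reduced = linalg.reduce ins(%in : ...) outs(%out : ...)
//         dimensions = [1]
//       (%arg0: f32, %arg1: f32) {
//         %0 = arith.addf %arg0, %arg1 : f32
//         linalg.yield %0 : f32
//       }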
1574 
1575 LogicalResult ReduceOp::verify() {
1576  ArrayRef<int64_t> dimensionsRef = getDimensions();
1577 
1578  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1579  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1580  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1581  return emitOpError() << "expects all inputs to have the same shapes. "
1582  "Shape at input-index "
1583  << i
1584  << " is not equal to the shape at input-index 0.";
1585  }
1586  }
1587  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1588  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1589  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1590  return emitOpError() << "expects all outputs to have the same shapes. "
1591  "Shape at output-index "
1592  << i
1593  << " is not equal to the shape at output-index 0.";
1594  }
1595  }
1596  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1597  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1598 
1599  DenseSet<int64_t> dimensionsToReduce;
1600  for (int64_t dimension : dimensionsRef) {
1601  if (dimension < 0 || dimension >= inputType.getRank()) {
1602  return emitOpError()
1603  << "dimensions for reduction should be in the range [0, "
1604  << inputType.getRank() - 1 << "].";
1605  }
1606  dimensionsToReduce.insert(dimension);
1607  }
1608 
1609  auto inputDims = inputType.getShape();
1610  auto initDims = initType.getShape();
1611 
1612  // Input dimensions that will be left after the reduction.
1613  SmallVector<int64_t> reducedInputDims;
1614  for (const auto &en : llvm::enumerate(inputDims)) {
1615  if (!dimensionsToReduce.count(en.index()))
1616  reducedInputDims.push_back(en.value());
1617  }
1618 
1619  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1620  return emitOpError() << "number of dimensions after reduction "
1621  << reducedInputDims.size()
1622  << " doesn't match the init rank "
1623  << initType.getRank();
1624  }
1625 
1626  if (reducedInputDims != initDims)
1627  return emitOpError() << "init dimensions [" << initDims
1628  << "] doesn't match input dimensions after reduction ["
1629  << reducedInputDims << "]";
1630 
1631  Block *block = getBody();
1632  if (block->getNumArguments() != this->getNumOperands())
1633  return emitOpError()
1634  << "mismatching number of operands and block arguments";
1635 
1636  // Check that the first block arguments match the element type of the inputs.
1637  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1638  Type inputElementType =
1639  llvm::cast<ShapedType>(input.getType()).getElementType();
1640  if (inputElementType != bbArg.getType())
1641  return emitOpError()
1642  << "input element type " << inputElementType
1643  << " does not match corresponding block argument type "
1644  << bbArg.getType();
1645  }
1646 
1647  // Check that the last block arguments match the element type of the outputs.
1648  for (auto [output, bbArg] : llvm::zip(
1649  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1650  auto outputElementType =
1651  llvm::cast<ShapedType>(output.getType()).getElementType();
1652  if (outputElementType != bbArg.getType())
1653  return emitOpError()
1654  << "output element type " << outputElementType
1655  << " does not match corresponding block argument type "
1656  << bbArg.getType();
1657  }
1658  return success();
1659 }
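// As a concrete instance of the checks above: reducing a tensor<4x8x16xf32>
// input with dimensions = [1] requires a tensor<4x16xf32> init and a block
// whose arguments are (f32, f32).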
1660 
1661 //===----------------------------------------------------------------------===//
1662 // TransposeOp
1663 //===----------------------------------------------------------------------===//
1664 
1665 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1666  Region &region, ValueRange inputs,
1667  ValueRange outputs) {
1668  buildGenericRegion(builder, loc, region, inputs, outputs,
1669  [](OpBuilder &b, Location loc, ValueRange args) {
1670  b.create<linalg::YieldOp>(loc, args[0]);
1671  });
1672 }
1673 
1674 void TransposeOp::build(::mlir::OpBuilder &builder,
1675  ::mlir::OperationState &result, Value input, Value init,
1676  DenseI64ArrayAttr permutation,
1677  ArrayRef<NamedAttribute> attributes) {
1678  result.addOperands(input);
1679  result.addOperands(init);
1680  result.addAttribute(getPermutationAttrName(result.name), permutation);
1681  result.addAttributes(attributes);
1682 
1683  // Add output types for `RankedTensorType` output arguments.
1684  Type initType = init.getType();
1685  if (llvm::isa<RankedTensorType>(initType))
1686  result.addTypes(initType);
1687 
1688  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1689  init);
1690 }
1691 
1692 void TransposeOp::build(::mlir::OpBuilder &builder,
1693  ::mlir::OperationState &result, Value input, Value init,
1694  ArrayRef<int64_t> permutation,
1695  ArrayRef<NamedAttribute> attributes) {
1696  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1697  attributes);
1698 }
1699 
1700 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1701  if (failed(parseDstStyleOp(
1702  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1703  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1704  })))
1705  return failure();
1706 
1707  OpBuilder builder(parser.getContext());
1708  buildIdentityRegion(builder, result.location, *result.addRegion(),
1709  /*inputs=*/result.operands,
1710  /*outputs=*/{});
1711  return success();
1712 }
1713 
1714 void TransposeOp::getAsmResultNames(
1715  function_ref<void(Value, StringRef)> setNameFn) {
1716  if (!getResults().empty())
1717  setNameFn(getResults().front(), "transposed");
1718 }
1719 
1720 void TransposeOp::print(OpAsmPrinter &p) {
1721  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1722  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1723  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1724 }
1725 
1726 LogicalResult TransposeOp::verify() {
1727  ArrayRef<int64_t> permutationRef = getPermutation();
1728 
1729  if (!isPermutationVector(permutationRef))
1730  return emitOpError("permutation is not valid");
1731 
1732  auto inputType = getInput().getType();
1733  auto initType = getInit().getType();
1734 
1735  int64_t rank = inputType.getRank();
1736 
1737  if (rank != initType.getRank())
1738  return emitOpError() << "input rank " << rank
1739  << " does not match init rank " << initType.getRank();
1740 
1741  if (rank != static_cast<int64_t>(permutationRef.size()))
1742  return emitOpError() << "size of permutation " << permutationRef.size()
1743  << " does not match the argument rank " << rank;
1744 
1745  auto inputDims = inputType.getShape();
1746  auto initDims = initType.getShape();
1747 
1748  for (int64_t i = 0; i < rank; ++i) {
1749  int64_t inputDim = inputDims[permutationRef[i]];
1750  int64_t initDim = initDims[i];
1751 
1752  if (inputDim != initDim) {
1753  return emitOpError() << "dim(result, " << i << ") = " << initDim
1754  << " doesn't match dim(input, permutation[" << i
1755  << "]) = " << inputDim;
1756  }
1757  }
1758 
1759  return success();
1760 }
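// Concretely: for permutation = [1, 0, 2] and an input of tensor<2x4x8xf32>,
// the init must be tensor<4x2x8xf32>, since dim(init, i) must equal
// dim(input, permutation[i]).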
1761 
1762 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1763  int64_t rank = getInit().getType().getRank();
1764  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1765 }
1766 
1767 ArrayAttr TransposeOp::getIndexingMaps() {
1768  Builder builder(getContext());
1769  int64_t rank = getInit().getType().getRank();
1770  return builder.getAffineMapArrayAttr(
1771  {builder.getMultiDimIdentityMap(rank),
1772  AffineMap::getPermutationMap(
1773  llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
1774 }
1775 
1776 void TransposeOp::getEffects(
1777  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1778  &effects) {
1779  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1780  getDpsInits());
1781 }
1782 
1783 //===----------------------------------------------------------------------===//
1784 // BroadcastOp
1785 //===----------------------------------------------------------------------===//
1786 
1787 void BroadcastOp::build(::mlir::OpBuilder &builder,
1788  ::mlir::OperationState &result, Value input, Value init,
1789  DenseI64ArrayAttr dimensions,
1790  ArrayRef<NamedAttribute> attributes) {
1791  result.addOperands(input);
1792  result.addOperands(init);
1793  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
1794  result.addAttributes(attributes);
1795 
1796  // Add output types for `RankedTensorType` output arguments.
1797  Type initType = init.getType();
1798  if (llvm::isa<RankedTensorType>(initType))
1799  result.addTypes(initType);
1800 
1801  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1802  init);
1803 }
1804 
1805 void BroadcastOp::build(::mlir::OpBuilder &builder,
1806  ::mlir::OperationState &result, Value input, Value init,
1807  ArrayRef<int64_t> dimensions,
1808  ArrayRef<NamedAttribute> attributes) {
1809  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
1810  attributes);
1811 }
1812 
1813 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
1814  if (failed(parseDstStyleOp(
1815  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1816  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1817  })))
1818  return failure();
1819 
1820  OpBuilder builder(parser.getContext());
1821  buildIdentityRegion(builder, result.location, *result.addRegion(),
1822  /*inputs=*/result.operands,
1823  /*outputs=*/{});
1824  return success();
1825 }
1826 
1827 void BroadcastOp::getAsmResultNames(
1828  function_ref<void(Value, StringRef)> setNameFn) {
1829  if (!getResults().empty())
1830  setNameFn(getResults().front(), "broadcasted");
1831 }
1832 
1833 void BroadcastOp::print(OpAsmPrinter &p) {
1834  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1835  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1836  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1837 }
1838 
1839 LogicalResult BroadcastOp::verify() {
1840  ArrayRef<int64_t> dimensionsRef = getDimensions();
1841 
1842  auto inputType = getInput().getType();
1843  auto initType = getInit().getType();
1844 
1845  int64_t inputRank = inputType.getRank();
1846  int64_t initRank = initType.getRank();
1847 
1848  auto inputShape = inputType.getShape();
1849  auto initShape = initType.getShape();
1850 
1851  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1852  return emitOpError() << "input rank plus added dimensions does not "
1853  "match init rank. input rank: "
1854  << inputRank
1855  << ", dimensions size: " << dimensionsRef.size()
1856  << ", init rank: " << initRank;
1857 
1858  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
1859  if (dim < 0 || dim >= initRank)
1860  return emitOpError() << "dimension " << idx
1861  << " is out of range. expected range: [0, "
1862  << initRank - 1 << "], got: " << dim;
1863  }
1864 
1865  // Mapping from input dims to init dims.
1866  SmallVector<int64_t> dimMap;
1867  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1868  if (!llvm::is_contained(dimensionsRef, dim))
1869  dimMap.push_back(dim);
1870  }
1871 
1872  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1873  // This dimension is mapped from the input. Init and input dims should
1874  // match.
1875  if (inputShape[inputDimIdx] != initShape[initDimIdx])
1876  return emitOpError() << "input dim " << inputDimIdx
1877  << " should match init dim " << initDimIdx
1878  << ". input: " << inputShape[inputDimIdx]
1879  << ", init: " << initShape[initDimIdx];
1880  }
1881 
1882  return success();
1883 }
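// For example, broadcasting an input of tensor<4x16xf32> with
// dimensions = [1] into an init of tensor<4x?x16xf32> is valid:
// `dimensions` lists the init dims that are added, and the remaining init
// dims must match the input dims in order.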
1884 
1885 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
1886  int64_t rank = getInit().getType().getRank();
1887  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1888 }
1889 
1890 ArrayAttr BroadcastOp::getIndexingMaps() {
1891  Builder builder(getContext());
1892  int64_t rank = getInit().getType().getRank();
1893  return builder.getAffineMapArrayAttr(
1894  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
1895  builder.getMultiDimIdentityMap(rank)});
1896 }
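// E.g. for a rank-3 init with dimensions = [1], the maps built above are:
//   input: (d0, d1, d2) -> (d0, d2)   (identity with the added dim dropped)
//   init:  (d0, d1, d2) -> (d0, d1, d2)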
1897 
1898 void BroadcastOp::getEffects(
1899  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1900  &effects) {
1901  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1902  getDpsInits());
1903 }
1904 
1905 //===----------------------------------------------------------------------===//
1906 // YieldOp
1907 //===----------------------------------------------------------------------===//
1908 
1909 void linalg::YieldOp::print(OpAsmPrinter &p) {
1910  if (getNumOperands() > 0)
1911  p << ' ' << getOperands();
1912  p.printOptionalAttrDict((*this)->getAttrs());
1913  if (getNumOperands() > 0)
1914  p << " : " << getOperandTypes();
1915 }
1916 
1917 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1918  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
1919  SmallVector<Type, 2> types;
1920  SMLoc loc = parser.getCurrentLocation();
1921  return failure(parser.parseOperandList(opInfo) ||
1922  parser.parseOptionalAttrDict(result.attributes) ||
1923  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1924  parser.resolveOperands(opInfo, types, loc, result.operands));
1925 }
1926 
1927 // Check that the number and types of the operands match the element types of
1928 // the LinalgOp interface's shaped operands.
1929 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1930  if (op.getNumOperands() != linalgOp.getNumDpsInits())
1931  return op.emitOpError("expected number of yield values (")
1932  << op.getNumOperands()
1933  << ") to match the number of inits / outs operands of the enclosing "
1934  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
1935 
1936  for (OpOperand &opOperand : op->getOpOperands()) {
1937  OpOperand *outputOperand =
1938  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1939  Type elementType = outputOperand->get().getType();
1940  if (isa<MemRefType, RankedTensorType>(elementType))
1941  elementType = getElementTypeOrSelf(outputOperand->get().getType());
1942  if (opOperand.get().getType() != elementType)
1943  return op.emitOpError("type of yield operand ")
1944  << (opOperand.getOperandNumber() + 1) << " ("
1945  << opOperand.get().getType() << ") doesn't match "
1946  << "the element type of the enclosing linalg.generic op ("
1947  << elementType << ")";
1948  }
1949  return success();
1950 }
1951 
1952 LogicalResult linalg::YieldOp::verify() {
1953  auto *parentOp = (*this)->getParentOp();
1954  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1955  return emitOpError("expected single non-empty parent region");
1956 
1957  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1958  return verifyYield(*this, linalgOp);
1959 
1960  return emitOpError("expected parent op with LinalgOp interface");
1961 }
1962 
1963 //===----------------------------------------------------------------------===//
1964 // IndexOp
1965 //===----------------------------------------------------------------------===//
1966 
1967 LogicalResult IndexOp::verify() {
1968  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1969  if (!linalgOp)
1970  return emitOpError("expected parent op with LinalgOp interface");
1971  if (linalgOp.getNumLoops() <= getDim())
1972  return emitOpError("expected dim (")
1973  << getDim() << ") to be lower than the number of loops ("
1974  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1975  return success();
1976 }
1977 
1978 /////// Operations corresponding to library calls defined with Tablegen ////////
1979 
1980 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1981 
1982 #define GET_OP_CLASSES
1983 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1984 
1985 #define GET_OP_CLASSES
1986 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1987 
1988 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
1989  unsigned rank,
1990  MLIRContext *context) {
1991  if (maybeMap)
1992  return *maybeMap;
1993  if (rank == 0)
1994  return AffineMap::get(context);
1995  return AffineMap::getMultiDimIdentityMap(rank, context);
1996 }
1997 
1998 SmallVector<AffineExpr, 4>
1999 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2000  MLIRContext *context) {
2001  SmallVector<AffineExpr, 4> res;
2002  res.reserve(num);
2003  for (unsigned i = 0; i < num; ++i)
2004  res.push_back(getAffineDimExpr(startIdx++, context));
2005  return res;
2006 }
2007 
2008 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2009  ArrayRef<AffineExpr> b) {
2010  auto rangeA = llvm::make_range(a.begin(), a.end());
2011  auto rangeB = llvm::make_range(b.begin(), b.end());
2012  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2013  return llvm::to_vector<4>(concatRanges);
2014 }
2015 
2016 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2017  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2018  ss << "view";
2019  for (auto size : memref.getShape())
2020  if (size < 0)
2021  ss << "sx";
2022  else
2023  ss << size << "x";
2024  if (failed(appendMangledType(ss, memref.getElementType())))
2025  return failure();
2026  if (auto as = memref.getMemorySpace()) {
2027  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2028  ss << "as" << attr.getInt();
2029  else
2030  return failure();
2031  }
2032  return success();
2033  }
2034  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2035  ss << "vector";
2036  llvm::interleave(
2037  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2038  if (failed(appendMangledType(ss, vec.getElementType())))
2039  return failure();
2040  return success();
2041  } else if (t.isSignlessIntOrIndexOrFloat()) {
2042  ss << t;
2043  return success();
2044  }
2045  return failure();
2046 }
2047 
2048 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2049  assert(isa<LinalgOp>(op));
2050  std::string name(op->getName().getStringRef().str());
2051  std::string fun = "";
2052  for (NamedAttribute kv : op->getAttrs()) {
2053  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2054  fun = stringifyEnum(ufa.getValue()).str() + "_";
2055  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2056  fun = stringifyEnum(bfa.getValue()).str() + "_";
2057  }
2058  }
2059  name.reserve(128);
2060  std::replace(name.begin(), name.end(), '.', '_');
2061  llvm::raw_string_ostream ss(name);
2062  ss << "_" << fun;
2063  for (Type t : op->getOperandTypes()) {
2064  if (failed(appendMangledType(ss, t)))
2065  return std::string();
2066  ss << "_";
2067  }
2068  std::string res = ss.str();
2069  res.pop_back();
2070  return res;
2071 }
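// As an illustration of this mangling, a linalg.matmul whose operands are
// all memref<?x?xf32> would yield a name along the lines of:
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
// where each dynamic dimension is mangled as "sx" and the memory-space
// suffix is omitted for the default space.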
2072 
2073 //===----------------------------------------------------------------------===//
2074 // Canonicalizers and Folders.
2075 //===----------------------------------------------------------------------===//
2076 
2077 namespace {
2078 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2079  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2080 
2081  LogicalResult matchAndRewrite(LinalgOp op,
2082  PatternRewriter &rewriter) const override {
2083  for (OpOperand &opOperand : op->getOpOperands()) {
2084  // Linalg "inputs" may be either tensor or memref type.
2085  // tensor<0xelt_type> is a convention that may not always mean
2086  // "0 iterations". Only erase in cases where we see memref<...x0x...>.
2087  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2088  if (!mt)
2089  continue;
2090  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2091  rewriter.eraseOp(op);
2092  return success();
2093  }
2094  }
2095  return failure();
2096  }
2097 };
2098 
2099 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2100 /// result that is more static than the linalg op.
2101 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2102  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2103 
2104  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2105  PatternRewriter &rewriter) const override {
2106  if (!tensor::canFoldIntoProducerOp(castOp))
2107  return failure();
2108 
2109  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2110  if (!linalgOp)
2111  return failure();
2112 
2113  // The cast can be in a conditionally reachable region, in which case
2114  // folding will generate invalid code. Only conservatively fold ops in the
2115  // same block for now.
2116  if (castOp->getBlock() != linalgOp->getBlock())
2117  return failure();
2118 
2119  OpBuilder::InsertionGuard guard(rewriter);
2120  rewriter.setInsertionPoint(linalgOp);
2121 
2122  Location loc = linalgOp.getLoc();
2123  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2124  unsigned resultNumber = resultValue.getResultNumber();
2125  auto resultType =
2126  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2127  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2128  // going from a more dynamic shape to a less dynamic shape. If the producer
2129  // for this cast, i.e. producer of the out operand, is also an operation
2130  // that folds with tensor.cast consumer (like this pattern), the cast will
2131  // continue to propagate as far up the stack as it can go.
2132  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2133  Value newOperand =
2134  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2135  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2136  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2137  linalgOp.getDpsInits().end());
2138  outputOperands[resultNumber] = newOperand;
2139  newOperands.append(outputOperands.begin(), outputOperands.end());
2140 
2141  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2142  linalgOp->result_type_end());
2143  resultTypes[resultNumber] = resultType;
2144  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2145 
2146  // Create a tensor.cast operation back to the original type.
2147  Value castBack = rewriter.create<tensor::CastOp>(
2148  loc, resultValue.getType(), newOp->getResult(resultNumber));
2149 
2150  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2151  results[resultNumber] = castBack;
2152  rewriter.replaceOp(linalgOp, results);
2153  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2154  return success();
2155  }
2156 };
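// A sketch of the rewrite this pattern performs, with illustrative shapes:
//
//   %res = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %cast = tensor.cast %res : tensor<?x?xf32> to tensor<4x8xf32>
//
// becomes
//
//   %new_out = tensor.cast %out : tensor<?x?xf32> to tensor<4x8xf32>
//   %new_res = linalg.generic ... outs(%new_out : tensor<4x8xf32>)
//       -> tensor<4x8xf32>
//   %cast_back = tensor.cast %new_res : tensor<4x8xf32> to tensor<?x?xf32>
//
// Users of %cast now use %new_res directly, and users of %res use %cast_back.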
2157 
2158 /// For each operand in `operands`, this function maps the static sizes of
2159 /// its dimensions to their affine dim expressions.
2160 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2161  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2162  for (OpOperand &opOperand : operands) {
2163  if (linalgOp.isScalar(&opOperand))
2164  continue;
2165  Value src = opOperand.get();
2166  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2167  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2168 
2169  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2170  // a `tensor.cast` operation and the source of the cast has a static
2171  // shape, then assign that shape to `sourceShape`.
2172  auto *parentOp = src.getDefiningOp();
2173  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2174  if (parentOp) {
2175  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2176  Value castSource = castOp.getSource();
2177  auto castSourceType =
2178  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2179  if (castSourceType && castSourceType.hasStaticShape())
2180  sourceShape = castSourceType.getShape();
2181  }
2182  }
2183 
2184  // If a dimension of the source shape is static, map the affine dim
2185  // expression to the known static size.
2186  for (unsigned i = 0; i < sourceShape.size(); i++) {
2187  if (sourceType.isDynamicDim(i))
2188  continue;
2189  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2190  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2191  }
2192  }
2193 }
2194 
2195 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2196 /// mapped in `affineExprToSize`. New operands are appended to `newOperands`
2197 /// and their result types are stored in `resultTypes`. If `opOperand` requires
2198 /// no change, `changeNeeded` is left untouched and the same operand is added
2199 /// to the `newOperands` list.
2200 static void createNewOperandWithStaticSizes(
2201  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2202  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2203  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2204  bool &changeNeeded) {
2205  Value src = opOperand->get();
2206  newOperands.push_back(src);
2207  if (linalgOp.isScalar(opOperand))
2208  return;
2209  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2210  Type resultType = sourceType;
2211  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2212  resultTypes.push_back(resultType);
2213  return;
2214  }
2215  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2216  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2217  SmallVector<int64_t> newShape;
2218  // If the operand is updated with a new shape, `newOperandNeeded` will be
2219  // true.
2220  bool newOperandNeeded = false;
2221  for (unsigned i = 0; i < sourceShape.size(); i++) {
2222  int64_t dimShape = sourceShape[i];
2223  AffineExpr dimExpr = sourceMap.getResult(i);
2224  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2225  newShape.push_back(dimShape);
2226  continue;
2227  }
2228  // Dimension has a dynamic shape and corresponding affine dim
2229  // expression is present in the map. So assign the size for the
2230  // given affine dim expression to the dimension.
2231  newShape.push_back(affineExprToSize[dimExpr]);
2232  newOperandNeeded = true;
2233  }
2234  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2235  if (newOperandNeeded) {
2236  changeNeeded = true;
2237  // Get the new operand value given its size and element type by
2238  // casting it.
2239  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2240  unsigned index = opOperand->getOperandNumber();
2241  newOperands[index] = newOperand;
2242  }
2243  if (linalgOp.isDpsInit(opOperand))
2244  resultTypes.push_back(resultType);
2245 }
2246 
2247 /// Static shapes for the operands can be inferred if any one of the operands
2248 /// has a static shape. This is done by referring to the affine dim
2249 /// expressions of the operands.
2250 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2251  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2252 
2253  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2254  PatternRewriter &rewriter) const override {
2255  if (!linalgOp.hasTensorSemantics())
2256  return failure();
2257 
2258  // Maps must be projected permutations.
2259  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2260  return !map.isProjectedPermutation();
2261  }))
2262  return failure();
2263 
2264  // Maps affine dim expressions to the static size of that dimension.
2265  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2266  Location loc = linalgOp.getLoc();
2267 
2268  // For each affine dim expression, check whether its size is known. If it
2269  // is, add it to the map.
2270  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2271 
2272  SmallVector<Value> newOperands;
2273  SmallVector<Type> resultTypes;
2274 
2275  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2276  // change in their types.
2277  bool changeNeeded = false;
2278  newOperands.reserve(linalgOp->getNumOperands());
2279  resultTypes.reserve(linalgOp.getNumDpsInits());
2280 
2281  // Iterate over all the operands and update the static sizes.
2282  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2283  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2284  affineExprToSize, linalgOp, newOperands,
2285  resultTypes, changeNeeded);
2286  }
2287 
2288  // If the generic op has all the required static information, no
2289  // canonicalization is needed.
2290  if (!changeNeeded)
2291  return failure();
2292 
2293  // Clone op.
2294  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2295  SmallVector<Value> replacements;
2296  replacements.reserve(newOp->getNumResults());
2297  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2298  Value newResult = std::get<1>(it);
2299  Value oldResult = std::get<0>(it);
2300  Type newType = newResult.getType();
2301  Type oldType = oldResult.getType();
2302  replacements.push_back(
2303  (newType != oldType)
2304  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2305  : newResult);
2306  }
2307  rewriter.replaceOp(linalgOp, replacements);
2308  return success();
2309  }
2310 };
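// For instance, with illustrative shapes and identity indexing maps:
//
//   %0 = linalg.generic ... ins(%a : tensor<4x8xf32>)
//        outs(%b : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// the pattern above infers that %b must also be 4x8, casts the operand to
// tensor<4x8xf32>, clones the op with the static types, and casts the result
// back to the original tensor<?x?xf32> type.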
2311 
2312 } // namespace
2313 
2314 // All named ops' canonicalizers and folders are auto-generated in the
2315 // .cpp.inc.
2316 
2317 //===----------------------------------------------------------------------===//
2318 // SoftmaxOp
2319 //===----------------------------------------------------------------------===//
2320 
2321 LogicalResult SoftmaxOp::verify() {
2322  ShapedType inputType = getInputOperandType();
2323  ShapedType outputType = getOutputOperandType();
2324 
2325  ArrayRef<int64_t> inputShape = inputType.getShape();
2326  ArrayRef<int64_t> outputShape = outputType.getShape();
2327  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2328  return emitOpError("incompatible output shape");
2329 
2330  int64_t inputRank = getInputOperandRank();
2331  int64_t dimension = getDimension();
2332  if ((dimension < 0) || (dimension >= inputRank))
2333  return emitOpError("incorrect dimension specified");
2334 
2335  return success();
2336 }
2337 
2338 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2339  int64_t operandRank = getInputOperandRank();
2340  SmallVector<Range> loopBounds(operandRank);
2341  Location loc = getLoc();
2342  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2343  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2344  Value source = getInput();
2345  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2346  loopBounds[dim].offset = zero;
2347  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2348  loopBounds[dim].stride = one;
2349  }
2350  return loopBounds;
2351 }
2352 
2353 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2354  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2355  utils::IteratorType::parallel);
2356  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2357  return iteratorTypes;
2358 }
2359 
2360 FailureOr<TilingResult>
2361 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2362  ArrayRef<OpFoldResult> offsets,
2363  ArrayRef<OpFoldResult> sizes) {
2364  int64_t rank = getInputOperandRank();
2365  auto oneAttr = builder.getI64IntegerAttr(1);
2366  SmallVector<OpFoldResult> strides(rank, oneAttr);
2367  SmallVector<Value> tiledOperands;
2368  tiledOperands.emplace_back(
2369  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2370  tiledOperands.emplace_back(
2371  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2372 
2373  SmallVector<Type, 4> resultTypes;
2374  if (hasTensorSemantics())
2375  resultTypes.push_back(tiledOperands[1].getType());
2376  Operation *tiledOp =
2377  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2378 
2379  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2380 }
2381 
2382 LogicalResult SoftmaxOp::getResultTilePosition(
2383  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2384  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2385  SmallVector<OpFoldResult> &resultSizes) {
2386  if (resultNumber == 0) {
2387  resultOffsets.assign(offsets.begin(), offsets.end());
2388  resultSizes.assign(sizes.begin(), sizes.end());
2389  return success();
2390  }
2391  return failure();
2392 }
2393 
2394 // cast(dynamic) -> static.
2395 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2396  return memref::foldMemRefCast(*this);
2397 }
2398 
2399 LogicalResult
2400 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2401  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2402  SmallVector<OpFoldResult> shapes;
2403  Location loc = getOperation()->getLoc();
2404  IRRewriter rewriter(b);
2405  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2406  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2407  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2408  if (!outputShapedType.isDynamicDim(dim)) {
2409  // Static dim: Return IntegerAttr.
2410  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2411  } else {
2412  // Dynamic dim: Return Value.
2413  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2414  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2415  }
2416  }
2417  reifiedReturnShapes.emplace_back(std::move(shapes));
2418  return success();
2419 }
2420 
2421 void SoftmaxOp::getEffects(
2422  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2423  &effects) {
2424  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2425  getDpsInits());
2426 }
2427 
2428 // Helper functions for softmax decomposition.
2429 // @{
2430 
2431 // Helper function to produce the iterator types (reduction or parallel) and
2432 // affine maps for the iterators used in the decomposition of softmax.
2433 // This method creates:
2434 // If allParallel == true:
2435 // - iterator type: {parallel, ..., parallel}
2436 // - affine maps:
2437 // -- identity with inputRank dimensions.
2438 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2439 // where N == inputRank.
2440 //
2441 // If allParallel == false:
2442 // - iterator type at dim(i) == parallel for i != \p dim and
2443 // dim(dim) == reduction.
2444 // - affine map:
2445 // -- identity with inputRank dimensions.
2446 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2447 // where N == inputRank.
2448 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2449 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2450  int64_t dim, bool allParallel = false) {
2451  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2452  utils::IteratorType::parallel);
2453  if (!allParallel)
2454  iteratorTypes[dim] = utils::IteratorType::reduction;
2455  MLIRContext *ctxt = builder.getContext();
2456  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2457  SmallVector<AffineExpr, 2> affineExprs;
2458  for (int i = 0; i < inputRank; i++) {
2459  if (i != dim)
2460  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2461  }
2462  auto reductionMap =
2463  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2464  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2465  return std::make_tuple(iteratorTypes, indexingMaps);
2466 }
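// For inputRank == 3 and dim == 1, the maps computed above are:
//   identityMap:  (d0, d1, d2) -> (d0, d1, d2)
//   reductionMap: (d0, d1, d2) -> (d0, d2)
// and the iterator types are {parallel, reduction, parallel}, or all
// parallel when allParallel is set.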
2467 
2468 // Helper function to produce a linalg.generic that computes a reduction on
2469 // dimension \p dim with the operation type \p T.
2470 template <typename T>
2471 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2472  int64_t dim) {
2473  auto inputType = cast<ShapedType>(input.getType());
2474  ArrayRef<int64_t> inputShape = inputType.getShape();
2475  int64_t inputRank = inputShape.size();
2476  auto [iteratorTypes, indexingMaps] =
2477  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2478  assert(indexingMaps.size() == 2 &&
2479  "We should have two maps: 1 for the input, 1 for the output");
2480  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2481 
2482  auto genericOp = builder.create<linalg::GenericOp>(
2483  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2484  [&](OpBuilder &b, Location loc, ValueRange args) {
2485  Value result = b.create<T>(loc, args[0], args[1]);
2486  b.create<linalg::YieldOp>(loc, result);
2487  });
2488  return genericOp.getResult(0);
2489 }
2490 
2491 /// Produce a linalg generic that computes the second step of the softmax
2492 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2493 /// on dimension \p dim.
2494 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2495  Value max, Value output, int64_t dim) {
2496  auto inputType = cast<ShapedType>(input.getType());
2497  ArrayRef<int64_t> inputShape = inputType.getShape();
2498  int64_t inputRank = inputShape.size();
2499  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2500  builder, inputRank, dim, /*allParallel=*/true);
2501  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2502  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2503  // Add the affine map for the output argument.
2504  indexingMaps.push_back(indexingMaps[0]);
2505  auto genericOp = builder.create<linalg::GenericOp>(
2506  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2507  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2508  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2509  Value result = b.create<math::ExpOp>(loc, diff);
2510  b.create<linalg::YieldOp>(loc, result);
2511  });
2512  return genericOp.getResult(0);
2513 }
2514 
2515 /// Produce a linalg generic that computes the final step of the softmax
2516 /// decomposition.
2517 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2518 /// yield n / d
2519 /// }
2520 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2521  Value denominator, Value output, int64_t dim) {
2522  auto inputType = cast<ShapedType>(numerator.getType());
2523  ArrayRef<int64_t> inputShape = inputType.getShape();
2524  int64_t inputRank = inputShape.size();
2525  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2526  builder, inputRank, dim, /*allParallel=*/true);
2527  assert(indexingMaps.size() == 2 &&
2528  "We should have one map for each input (2)");
2529  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2530  // Add the affine map for the output tensor.
2531  indexingMaps.push_back(indexingMaps[0]);
2532  auto genericOp = builder.create<linalg::GenericOp>(
2533  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2534  indexingMaps, iteratorTypes,
2535  [&](OpBuilder &b, Location loc, ValueRange args) {
2536  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2537  b.create<linalg::YieldOp>(loc, result);
2538  });
2539  return genericOp.getResult(0);
2540 }
2541 // @} End helper functions for softmax decomposition.
2542 
2543 /// Given an N-dimensional tensor x, this method converts
2544 /// softmax(x) to the following sequence of operations:
2545 ///
2546 /// 1. Compute the max of x along dimension d. This results
2547 /// in an N-1 dimensional tensor m.
2548 /// m = max(x, dim = d)
2549 ///
2550 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2551 /// an N dimensional tensor z.
2552 /// z = exp(x - m)
2553 ///
2554 /// 3. Compute the sum of z along dimension d. This results in
2555 /// an N-1 dimensional tensor l.
2556 /// l = sum(z, dim = d)
2557 ///
2558 /// 4. Divide z and l. This gives the N-dimensional softmax.
2559 /// softmax = z / l
2560 ///
2561 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2562  OpBuilder::InsertionGuard guard(b);
2563  b.setInsertionPoint(*this);
2564  Location loc = getLoc();
2565  Value input = getInput();
2566  ShapedType inputType = getInputOperandType();
2567  Type elementType = inputType.getElementType();
2568  int64_t reductionDim = getDimension();
2569  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2570  Value output = getOutput();
2571  dims.erase(dims.begin() + reductionDim);
2572  // Step 1: Compute max along dim.
2573  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2574  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2575  elementType, b, loc,
2576  /*useOnlyFiniteValue=*/true);
2577  Value neutralForMaxFInit =
2578  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2579  .result();
2580  Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
2581  reductionDim);
2582 
2583  // Step 2: Subtract max from input and exponentiate.
2584  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2585 
2586  // Step 3: Compute sum along dim.
2587  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2588  b, loc, /*useOnlyFiniteValue=*/true);
2589  Value zeroInit =
2590  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2591  Value denominator =
2592  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2593 
2594  // Step 4: Compute softmax.
2595  Value result =
2596  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2597  return SmallVector<Value>{result};
2598 }
2599 
2600 //===----------------------------------------------------------------------===//
2601 // LinalgDialect
2602 //===----------------------------------------------------------------------===//
2603 
2604 void LinalgDialect::getCanonicalizationPatterns(
2605  RewritePatternSet &results) const {
2606  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2607  InferStaticShapeOfOperands>(getContext());
2608 }
2609 
2610 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
2611  Attribute value, Type type,
2612  Location loc) {
2613  return arith::ConstantOp::materialize(builder, value, type, loc);
2614 }
static bool hasTensorSemantics(Operation *op)
Return true if the given op has a tensor result or a tensor operand.
Definition: Bufferize.cpp:356
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1665
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:271
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2449
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2520
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:118
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2016
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, const ValueRange inputOperands, ValueRange outputOperands)
Definition: LinalgOps.cpp:1053
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:259
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:330
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1498
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1546
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2494
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1304
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:155
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1177
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2471
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1929
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:297
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:828
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:290
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:50
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1231
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1333
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:323
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:185
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:290
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:312
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:391
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:68
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:147
OpListType & getOperations()
Definition: Block.h:130
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:376
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:710
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:421
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:419
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator range over the underlying operand Values.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
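These generic accessors allow inspecting any operation without knowing its concrete type; a sketch (assumes the usual llvm/mlir headers):

static void dumpOperandInfo(Operation *op) {
  llvm::errs() << op->getName().getStringRef() << ": "
               << op->getNumOperands() << " operands, "
               << op->getNumResults() << " results\n";
  for (OpOperand &operand : op->getOpOperands())
    llvm::errs() << "  operand #" << operand.getOperandNumber()
                 << " of type " << operand.get().getType() << "\n";
}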
This class represents success/failure for parsing-like operations that find it important to chain together failable operations with `||`.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
Definition: PatternMatch.h:727
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
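The rewriter entries above combine into the standard pattern idiom. A sketch against a hypothetical MyNoopOp (op class and getInput accessor are illustrative) that simply forwards its operand:

struct EraseMyNoop : OpRewritePattern<MyNoopOp> {
  using OpRewritePattern<MyNoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyNoopOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getInput().hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected a single use");
    // Replace all uses of the op's result with its operand, then erase it.
    rewriter.replaceOp(op, op.getInput());
    return success();
  }
};
// Registered via: patterns.add<EraseMyNoop>(context);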
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:105
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:211
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
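A sketch of chasing def-use information with the Value API above:

static void describeValue(Value v) {
  if (Operation *def = v.getDefiningOp())
    llvm::errs() << "defined by " << def->getName().getStringRef()
                 << " at " << v.getLoc() << "\n";
  else
    llvm::errs() << "block argument of type " << v.getType() << "\n";
  if (v.hasOneUse())
    llvm::errs() << "  (single use)\n";
}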
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying those operands, then immediately attempts to fold it.
Definition: AffineOps.cpp:1172
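This helper is the usual way to build index arithmetic that disappears when its inputs are static. A sketch computing lhs + rhs over OpFoldResults (namespaces per LLVM 18, where the helper lives in mlir::affine):

static OpFoldResult addIndex(OpBuilder &b, Location loc, OpFoldResult lhs,
                             OpFoldResult rhs) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  // Folds to an index attribute when both inputs are constant; otherwise
  // materializes a composed affine.apply.
  return affine::makeComposedFoldedAffineApply(
      b, loc, AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1),
      {lhs, rhs});
}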
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2503
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2008
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:97
std::string generateLibraryCallName(Operation *op)
Returns the name-mangled library call name used to disambiguate between different overloads at the C level.
Definition: LinalgOps.cpp:2048
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns the contained map if maybeMap is set; otherwise returns the symbol-less identity map of the given rank.
Definition: LinalgOps.cpp:1988
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:1999
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:88
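These linalg helpers are used throughout this file when assembling indexing maps for named ops. An illustrative sketch allocating two batch dimensions followed by one reduction dimension:

static SmallVector<AffineExpr, 4> makeBatchedDims(MLIRContext *ctx) {
  unsigned startIdx = 0;
  // `startIdx` advances past each allocated dimension: d0, d1 then d2.
  SmallVector<AffineExpr, 4> batch =
      linalg::makeAffineDimExprs(2, startIdx, ctx);
  SmallVector<AffineExpr, 4> reduction =
      linalg::makeAffineDimExprs(1, startIdx, ctx);
  return linalg::concat(batch, reduction); // d0, d1, d2
}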
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
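LinalgOps.cpp uses foldMemRefCast in exactly this shape for the memref-based ops; a sketch for a hypothetical MyOp (FoldAdaptor is the ODS-generated adaptor type):

LogicalResult MyOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Absorb producing memref.cast ops into this op's operands when legal.
  return memref::foldMemRefCast(*this);
}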
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
MPInt ceil(const Fraction &f)
Definition: Fraction.h:76
MPInt floor(const Fraction &f)
Definition: Fraction.h:74
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:19
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:340
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:168
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
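The success/failure utilities compose verifier-style control flow; a small sketch using getElementTypeOrSelf (documented below):

static LogicalResult verifySameElementType(Operation *op, Value a, Value b) {
  if (getElementTypeOrSelf(a.getType()) != getElementTypeOrSelf(b.getType()))
    return op->emitOpError("expected matching element types");
  return success();
}
// Caller side:
//   if (failed(verifySameElementType(op, x, y)))
//     return failure();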
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute into the given MLIR context, if valid.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
Definition: AffineExpr.h:348
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to construct affine expressions without using the detail classes directly.
Definition: AffineExpr.cpp:584
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs; emits diagnostics and returns failure on error.
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:372
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:361
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.
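OperationState is the generic assembly buffer that the create<> builders fill in; constructing an op by hand looks like this sketch (the op and attribute names are hypothetical):

static Operation *buildGenericOp(OpBuilder &b, Location loc, Value lhs,
                                 Value rhs) {
  // Hypothetical op name; the dialect must be loaded (or unregistered
  // dialects allowed) in the MLIRContext for creation to succeed.
  OperationState state(loc, "mydialect.myop");
  state.addOperands({lhs, rhs});
  state.addAttribute("some_flag", b.getUnitAttr());
  Type resultType = b.getIndexType();
  state.addTypes(resultType);
  return b.create(state);
}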
Container for result values of tiling.