MLIR  20.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/PatternMatch.h"
38 
39 #include "llvm/ADT/DenseMap.h"
40 #include "llvm/ADT/SmallSet.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "llvm/Support/MathExtras.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include <optional>
47 
48 using namespace mlir;
49 using namespace mlir::linalg;
50 
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
// NOTE(review): the opening line(s) of this signature (return type, name, and
// the `builder`, `loc`, `v` parameters) were lost in extraction — restore
// from the original source before building.
                             int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  // Static dimensions fold directly to a constant index attribute; no IR is
  // emitted.
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  // Dynamic dimension: materialize a dim op matching the container kind.
  // NOTE(review): the TypeSwitch head line (before the .Case chain) is also
  // missing from this extraction.
  return getAsOpFoldResult(
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return builder.create<tensor::DimOp>(loc, v, dim);
      })
      .Case<MemRefType>([&](MemRefType t) -> Value {
        return builder.create<memref::DimOp>(loc, v, dim);
      }));
}
67 
68 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
69 /// `source`.
70 static Operation *getSlice(OpBuilder &b, Location loc, Value source,
71  ArrayRef<OpFoldResult> offsets,
73  ArrayRef<OpFoldResult> strides) {
74  return TypeSwitch<Type, Operation *>(source.getType())
75  .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
76  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
77  strides);
78  })
79  .Case<MemRefType>([&](MemRefType type) -> Operation * {
80  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
81  strides);
82  })
83  .Default([&](Type t) -> Operation * { return nullptr; });
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Helper functions
88 //===----------------------------------------------------------------------===//
89 
91  int64_t dim) {
92  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
93  return b.createOrFold<memref::DimOp>(loc, source, dim);
94  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
95  return b.createOrFold<tensor::DimOp>(loc, source, dim);
96  llvm_unreachable("Expected MemRefType or TensorType");
97 }
98 
// Returns the size of `source` at `dim` as a foldable result: a constant
// index attribute for statically-known dims, otherwise an SSA dim value.
// NOTE(review): the signature's first line (return type, name, and the `b`,
// `loc`, `source` parameters) was lost in extraction — restore from the
// original source.
                             int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  // Unranked types and dynamic dims need an IR value; delegate to the
  // op-creating helper above.
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  // Static dims fold to a constant index attribute.
  return b.getIndexAttr(shapedType.getDimSize(dim));
}
106 
107 //===----------------------------------------------------------------------===//
108 // Support for named Linalg ops defined in ods-gen.
109 //===----------------------------------------------------------------------===//
110 
113 
114 /// Fills the region of a structured operation using the provided
115 /// `regionBuilder`. The method is used by both named structured ops created by
116 /// ods-gen and by manually defined C++ ops. It is called by both builders and
117 /// parsers and creates a block with arguments corresponding to the elemental
118 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
119 /// ShapedType.
120 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
121  TypeRange inputTypes, TypeRange outputTypes,
123  RegionBuilderFn regionBuilder) {
124  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
125 
126  SmallVector<Type, 8> argTypes;
127  SmallVector<Location, 8> argLocs;
128  for (auto containers : {inputTypes, outputTypes}) {
129  for (auto t : containers) {
130  argTypes.push_back(
131  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
132 
133  // TODO: Pass in a proper location here.
134  argLocs.push_back(opBuilder.getUnknownLoc());
135  }
136  }
137 
138  // RAII.
139  OpBuilder::InsertionGuard guard(opBuilder);
140  Block *body =
141  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
142 
143  opBuilder.setInsertionPointToStart(body);
144  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
145  regionBuilder(b, *body, attrs);
146 
147  // indexing_maps is an auto-generated method.
148 
149  // iterator_types is an auto-generated method.
150 }
151 
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
// NOTE(review): the signature's first line (return type, name, and the
// `b`/`state` parameters) was lost in extraction; from the body, `b` is an
// OpBuilder and `state` an OperationState — confirm against the original.
                    std::optional<TypeRange> resultTensorTypes,
                    ValueRange inputs, ValueRange outputs,
                    ArrayRef<NamedAttribute> attributes,
                    RegionBuilderFn regionBuilder) {
  // Derive the result types if needed: when no explicit result types are
  // given, every ranked-tensor output becomes a result (memref outputs
  // contribute none).
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  // Record the input/output operand split for the segment-aware accessors.
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}
182 
183 /// Common parsing used for both named structured ops created by ods-gen and by
184 /// manually defined C++ ops. Does not handle regions.
185 static ParseResult
187  SmallVectorImpl<Type> &inputTypes,
188  SmallVectorImpl<Type> &outputTypes,
189  bool addOperandSegmentSizes = true) {
190  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
192  outputsOperands;
193 
194  if (succeeded(parser.parseOptionalLess())) {
195  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
196  return failure();
197  }
198  attrsLoc = parser.getCurrentLocation();
199  if (parser.parseOptionalAttrDict(result.attributes))
200  return failure();
201 
202  if (succeeded(parser.parseOptionalKeyword("ins"))) {
203  if (parser.parseLParen())
204  return failure();
205 
206  inputsOperandsLoc = parser.getCurrentLocation();
207  if (parser.parseOperandList(inputsOperands) ||
208  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
209  return failure();
210  }
211 
212  if (succeeded(parser.parseOptionalKeyword("outs"))) {
213  outputsOperandsLoc = parser.getCurrentLocation();
214  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
215  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
216  return failure();
217  }
218 
219  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
220  result.operands) ||
221  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
222  result.operands))
223  return failure();
224 
225  if (addOperandSegmentSizes) {
226  // This is a bit complex because we're trying to be backward compatible with
227  // operation syntax that mix the inherent attributes and the discardable
228  // ones in the same dictionary. If the properties are used, we append the
229  // operandSegmentSizes there directly. Otherwise we append it to the
230  // discardable attributes dictionary where it is handled by the generic
231  // Operation::create(...) method.
232  if (result.propertiesAttr) {
233  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
234  attrs.append("operandSegmentSizes",
236  {static_cast<int32_t>(inputsOperands.size()),
237  static_cast<int32_t>(outputsOperands.size())}));
238  result.propertiesAttr = attrs.getDictionary(parser.getContext());
239  } else {
240  result.addAttribute("operandSegmentSizes",
242  {static_cast<int32_t>(inputsOperands.size()),
243  static_cast<int32_t>(outputsOperands.size())}));
244  }
245  }
246  if (!result.propertiesAttr) {
247  std::optional<RegisteredOperationName> info =
248  result.name.getRegisteredInfo();
249  if (info) {
250  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
251  return parser.emitError(attrsLoc)
252  << "'" << result.name.getStringRef() << "' op ";
253  })))
254  return failure();
255  }
256  }
257  return success();
258 }
259 
261  ValueRange outputs) {
262  if (!inputs.empty())
263  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
264  if (!outputs.empty())
265  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Specific parsing and printing for named structured ops created by ods-gen.
270 //===----------------------------------------------------------------------===//
271 
272 static ParseResult parseNamedStructuredOpRegion(
273  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
274  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
275  RegionBuilderFn regionBuilder) {
276  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
277  return parser.emitError(
278  parser.getCurrentLocation(),
279  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
280  "region expects {0} args, got {1}",
281  numRegionArgs, inputTypes.size() + outputTypes.size()));
282  }
283 
284  OpBuilder opBuilder(parser.getContext());
285  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
286  regionBuilder);
287  return success();
288 }
289 
290 static ParseResult
292  SmallVectorImpl<Type> &resultTypes) {
293  if (parser.parseOptionalArrowTypeList(resultTypes))
294  return failure();
295  return success();
296 }
297 
298 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
299  OperationState &result,
300  unsigned numRegionArgs,
301  RegionBuilderFn regionBuilder) {
302  // TODO: Enable when ods-gen supports captures.
303  SmallVector<Type, 1> inputTypes, outputTypes;
304  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
305  return failure();
306 
307  // TODO: consider merging results parsing into region parsing.
308  // Need to wait for declarative assembly resolution to decide.
309  SmallVector<Type, 1> outputTensorsTypes;
310  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
311  return failure();
312  result.addTypes(outputTensorsTypes);
313 
314  std::unique_ptr<Region> region = std::make_unique<Region>();
315  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
316  outputTypes, result.attributes.getAttrs(),
317  regionBuilder))
318  return failure();
319  result.addRegion(std::move(region));
320 
321  return success();
322 }
323 
// Print the optional `-> type-list` results clause; elided when the op has no
// results.
// NOTE(review): the signature's first line (name and the printer parameter
// `p`) was lost in extraction — restore from the original source.
                  TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}
330 
// Print a named structured op: attribute dictionary (with generated /
// memoized attributes elided), then the shared ins/outs clauses; the region
// itself is elided.
// NOTE(review): the signature's first line (name and the `p`/`op` parameters)
// was lost in extraction.
                   ValueRange inputs, ValueRange outputs) {
  // NOTE(review): the line beginning this attribute-printing call is missing
  // from this extraction; only its trailing arguments remain below.
                  op->getAttrs(),
                  /*elidedAttrs=*/{"operandSegmentSizes",
                                   // See generated code in
                                   // LinalgNamedStructuredOps.yamlgen.cpp.inc
                                   "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  // NOTE(review): the statement that printed the result types is missing from
  // this extraction.

  // Region is elided.
}
349 
350 //===----------------------------------------------------------------------===//
351 // Region builder helper.
352 // TODO: Move this to a utility library.
353 // The public methods on this class are referenced directly from generated code.
354 // Helper build the unary, binary, and type conversion functions defined by the
355 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
356 // class.
357 //
358 // Implementations of the math functions must be polymorphic over numeric types,
359 // internally performing necessary casts. If the function application makes no
360 // sense, then the only recourse is to assert and return nullptr. This can be
361 // extended later if it becomes possible to fail construction of the region. The
362 // invariant should be enforced at a higher level.
363 //
364 // TODO: These helpers are currently type polymorphic over the class of integer
365 // and floating point types, but they will not internally cast within bit
366 // widths of a class (mixed precision such as i8->i32) or across classes
367 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
368 // to be handled with care and work is being considered to extend the op
369 // language to make such cases explicit. In the mean-time, violating this will
370 // fail verification, which is deemed acceptable.
371 //===----------------------------------------------------------------------===//
372 
373 namespace {
374 
375 class RegionBuilderHelper {
376 public:
377  RegionBuilderHelper(OpBuilder &builder, Block &block)
378  : builder(builder), block(block) {}
379 
380  // Build the unary functions defined by OpDSL.
381  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
382  if (!isFloatingPoint(arg))
383  llvm_unreachable("unsupported non numeric type");
384  OpBuilder::InsertionGuard g(builder);
385  builder.setInsertionPointToEnd(&block);
386  switch (unaryFn) {
387  case UnaryFn::exp:
388  return builder.create<math::ExpOp>(arg.getLoc(), arg);
389  case UnaryFn::log:
390  return builder.create<math::LogOp>(arg.getLoc(), arg);
391  case UnaryFn::abs:
392  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
393  case UnaryFn::ceil:
394  return builder.create<math::CeilOp>(arg.getLoc(), arg);
395  case UnaryFn::floor:
396  return builder.create<math::FloorOp>(arg.getLoc(), arg);
397  case UnaryFn::negf:
398  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
399  case UnaryFn::reciprocal: {
400  Attribute oneAttr = builder.getOneAttr(arg.getType());
401  auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
402  ::cast<TypedAttr>(oneAttr));
403  return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
404  }
405  case UnaryFn::round:
406  return builder.create<math::RoundOp>(arg.getLoc(), arg);
407  case UnaryFn::sqrt:
408  return builder.create<math::SqrtOp>(arg.getLoc(), arg);
409  case UnaryFn::rsqrt:
410  return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
411  case UnaryFn::square:
412  return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
413  case UnaryFn::tanh:
414  return builder.create<math::TanhOp>(arg.getLoc(), arg);
415  case UnaryFn::erf:
416  return builder.create<math::ErfOp>(arg.getLoc(), arg);
417  }
418  llvm_unreachable("unsupported unary function");
419  }
420 
421  // Build the binary functions defined by OpDSL.
422  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
423  bool allComplex = isComplex(arg0) && isComplex(arg1);
424  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
425  bool allInteger = isInteger(arg0) && isInteger(arg1);
426  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
427  arg1.getType().getIntOrFloatBitWidth() == 1;
428  if (!allComplex && !allFloatingPoint && !allInteger)
429  llvm_unreachable("unsupported non numeric type");
430  OpBuilder::InsertionGuard g(builder);
431  builder.setInsertionPointToEnd(&block);
432  switch (binaryFn) {
433  case BinaryFn::add:
434  if (allComplex)
435  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
436  if (allFloatingPoint)
437  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
438  if (allBool)
439  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
440  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
441  case BinaryFn::sub:
442  if (allComplex)
443  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
444  if (allFloatingPoint)
445  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
446  if (allBool)
447  llvm_unreachable("unsupported operation: sub with bools");
448  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
449  case BinaryFn::mul:
450  if (allComplex)
451  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
452  if (allFloatingPoint)
453  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
454  if (allBool)
455  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
456  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
457  case BinaryFn::div:
458  if (allComplex)
459  return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
460  if (allFloatingPoint)
461  return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
462  if (allBool)
463  llvm_unreachable("unsupported operation: div with bools");
464  return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
465  case BinaryFn::div_unsigned:
466  if (!allInteger || allBool)
467  llvm_unreachable("unsupported operation: unsigned div not on uint");
468  return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
469  case BinaryFn::max_signed:
470  assert(!allComplex);
471  if (allFloatingPoint)
472  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
473  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
474  case BinaryFn::min_signed:
475  assert(!allComplex);
476  if (allFloatingPoint)
477  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
478  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
479  case BinaryFn::max_unsigned:
480  assert(!allComplex);
481  if (allFloatingPoint)
482  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
483  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
484  case BinaryFn::min_unsigned:
485  assert(!allComplex);
486  if (allFloatingPoint)
487  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
488  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
489  case BinaryFn::powf:
490  assert(allFloatingPoint);
491  return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
492  }
493  llvm_unreachable("unsupported binary function");
494  }
495 
496  // Build the ternary functions defined by OpDSL.
497  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
498  Value arg2) {
499  bool headBool =
500  isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
501  bool tailFloatingPoint =
502  isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
503  bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
504  OpBuilder::InsertionGuard g(builder);
505  builder.setInsertionPointToEnd(&block);
506  switch (ternaryFn) {
507  case TernaryFn::select:
508  if (!headBool && !(tailFloatingPoint || tailInteger))
509  llvm_unreachable("unsupported non numeric type");
510  return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
511  }
512  llvm_unreachable("unsupported ternary function");
513  }
514 
515  // Build the type functions defined by OpDSL.
516  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
517  switch (typeFn) {
518  case TypeFn::cast_signed:
519  return cast(toType, operand, false);
520  case TypeFn::cast_unsigned:
521  return cast(toType, operand, true);
522  }
523  llvm_unreachable("unsupported type conversion function");
524  }
525 
526  void yieldOutputs(ValueRange values) {
527  OpBuilder::InsertionGuard g(builder);
528  builder.setInsertionPointToEnd(&block);
529  Location loc = builder.getUnknownLoc();
530  builder.create<YieldOp>(loc, values);
531  }
532 
533  Value constant(const std::string &value) {
534  OpBuilder::InsertionGuard g(builder);
535  builder.setInsertionPointToEnd(&block);
536  Location loc = builder.getUnknownLoc();
537  Attribute valueAttr = parseAttribute(value, builder.getContext());
538  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
539  }
540 
541  Value index(int64_t dim) {
542  OpBuilder::InsertionGuard g(builder);
543  builder.setInsertionPointToEnd(&block);
544  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
545  }
546 
547  Type getIntegerType(unsigned width) {
548  return IntegerType::get(builder.getContext(), width);
549  }
550 
551  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
552  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
553 
554 private:
555  // Generates operations to cast the given operand to a specified type.
556  // If the cast cannot be performed, a warning will be issued and the
557  // operand returned as-is (which will presumably yield a verification
558  // issue downstream).
559  Value cast(Type toType, Value operand, bool isUnsignedCast) {
560  OpBuilder::InsertionGuard g(builder);
561  builder.setInsertionPointToEnd(&block);
562  auto loc = operand.getLoc();
563  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
564  }
565 
566  bool isComplex(Value value) {
567  return llvm::isa<ComplexType>(value.getType());
568  }
569  bool isFloatingPoint(Value value) {
570  return llvm::isa<FloatType>(value.getType());
571  }
572  bool isInteger(Value value) {
573  return llvm::isa<IntegerType>(value.getType());
574  }
575 
576  OpBuilder &builder;
577  Block &block;
578 };
579 
580 } // namespace
581 
582 //===----------------------------------------------------------------------===//
583 // CopyOp
584 //===----------------------------------------------------------------------===//
585 
586 namespace {
587 
588 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
590  LogicalResult matchAndRewrite(CopyOp copyOp,
591  PatternRewriter &rewriter) const override {
592  if (copyOp.getInputs() != copyOp.getOutputs())
593  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
594  if (copyOp.hasPureBufferSemantics())
595  rewriter.eraseOp(copyOp);
596  else
597  rewriter.replaceOp(copyOp, copyOp.getInputs());
598 
599  return success();
600  }
601 };
602 
603 } // namespace
604 
/// Register CopyOp canonicalization patterns; currently only the erasure of
/// self-copies (EraseSelfCopy).
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
609 
610 //===----------------------------------------------------------------------===//
611 // FillOp
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 
616 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
617 ///
618 /// For such op chains, we can create new linalg.fill ops with the result
619 /// type of the tensor.expand/collapse_shape op.
620 template <typename TensorReshapeOp>
621 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
623  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
624  PatternRewriter &rewriter) const override {
625  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
626  if (!oldFill)
627  return failure();
628 
629  Location loc = oldFill.getLoc();
630  TensorReshapeOp newInit;
631  if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
632 
633  newInit = rewriter.create<TensorReshapeOp>(
634  loc, reshapeOp.getResultType(), oldFill.output(),
635  reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
636  reshapeOp.getStaticOutputShape());
637  } else {
638  newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
639  oldFill.output(),
640  reshapeOp.getReassociation());
641  }
642  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
643  ValueRange{newInit});
644  return success();
645  }
646 };
647 
648 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
649 /// filling value are the same.
650 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
652 
653  LogicalResult matchAndRewrite(tensor::PadOp padOp,
654  PatternRewriter &rewriter) const override {
655  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
656  if (!fillOp)
657  return failure();
658 
659  // We can only fold if the padding value is the same as the original
660  // filling value.
661  Value padValue = padOp.getConstantPaddingValue();
662  if (!padValue || fillOp.value() != padValue)
663  return failure();
664 
665  ReifiedRankedShapedTypeDims reifiedShape;
666  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
667  return rewriter.notifyMatchFailure(
668  padOp, "failed to reify tensor.pad op result shape");
669 
670  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
671  padOp.getLoc(), reifiedShape.front(),
672  padOp.getResultType().getElementType());
673  Value replacement =
674  rewriter
675  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
676  ValueRange{emptyTensor})
677  .getResult(0);
678  if (replacement.getType() != padOp.getResultType()) {
679  replacement = rewriter.create<tensor::CastOp>(
680  fillOp.getLoc(), padOp.getResultType(), replacement);
681  }
682  rewriter.replaceOp(padOp, replacement);
683  return success();
684  }
685 };
686 
687 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
688 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
689 /// filling value are the same.
690 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
692 
693  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
694  PatternRewriter &rewriter) const override {
695  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
696  if (!srcPadOp)
697  return failure();
698 
699  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
700  return failure();
701 
702  // Walk back the tensor.insert_slice chain and find the first destination
703  // value at the start of the chain.
704  Value firstDest = insertOp.getDest();
705  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
706  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
707  return failure();
708 
709  // Make sure the range of values accessed are disjoint. Without this, we
710  // cannot fold tensor.pad away.
711  bool disjoint = false;
712  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
713  // If the dimension has dynamic offset/size, we cannot guarantee
714  // disjoint. So just skip it.
715  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
716  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
717  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
718  continue;
719 
720  // Get the range start and end, inclusively for both.
721  int64_t prevStart = prevOp.getStaticOffset(i);
722  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
723  prevOp.getStaticStride(i);
724  int64_t nextStart = insertOp.getStaticOffset(i);
725  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
726  insertOp.getStaticStride(i);
727  if (prevEnd < nextStart || nextEnd < prevStart) {
728  disjoint = true;
729  break;
730  }
731  }
732 
733  if (!disjoint)
734  break;
735  firstDest = prevOp.getDest();
736  }
737 
738  // Check whether the first destination is a fill op. For overlapped cases,
739  // this also cannot be true.
740  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
741  if (!dstFillOp)
742  return failure();
743 
744  // We can only fold if the padding value is the same as the original
745  // filling value.
746  Value padValue = srcPadOp.getConstantPaddingValue();
747  if (!padValue || dstFillOp.value() != padValue)
748  return failure();
749 
750  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
751  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
752 
753  Location loc = insertOp.getLoc();
754  MLIRContext *context = getContext();
755 
756  AffineExpr sym0, sym1;
757  bindSymbols(context, sym0, sym1);
758  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
759 
760  // Calculate the new offsets for the insert. It should be the old offsets
761  // plus low padding sizes.
762  SmallVector<OpFoldResult, 4> newOffsets;
763  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
764  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
765  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
766  }
767 
768  RankedTensorType srcPadType = srcPadOp.getSourceType();
770  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
771  if (srcPadType.isDynamicDim(i)) {
772  newSizes.push_back(
773  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
774  .getResult());
775  } else {
776  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
777  }
778  }
779 
780  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
781  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
782  newSizes, insertOp.getMixedStrides());
783  return success();
784  }
785 };
786 
787 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
788 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
789 public:
791 
792  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
793  PatternRewriter &rewriter) const override {
794  // See if tensor input of tensor.extract op is the result of a linalg.fill
795  // op.
796  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
797  if (!fillOp)
798  return failure();
799 
800  // Get scalar input operand of linalg.fill op.
801  Value extractedScalar = fillOp.getInputs()[0];
802 
803  // Replace tensor.extract op with scalar value used to fill the tensor.
804  rewriter.replaceOp(extractOp, extractedScalar);
805  return success();
806  }
807 };
808 
809 /// Folds pack(fill) into a single fill op if
810 /// 1. The pack op does not have padding value, or
811 /// 2. The filled value and padding value are the same.
812 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
813  tensor::PackOp packOp) {
814  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
815  if (!fillOp)
816  return failure();
817 
818  if (auto paddingValue = packOp.getPaddingValue())
819  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
820  return failure();
821 
822  Value packOpDest = packOp.getDest();
823  if (!packOpDest.hasOneUse())
824  return failure();
825 
826  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
827  packOp.getDest());
828 }
829 
830 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
831 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
832 public:
833  FoldFillWithPack(MLIRContext *context)
834  : OpRewritePattern<tensor::PackOp>(context) {}
835 
836  LogicalResult matchAndRewrite(tensor::PackOp packOp,
837  PatternRewriter &rewriter) const override {
838  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
839  if (failed(fillOp))
840  return failure();
841  rewriter.replaceOp(packOp, fillOp.value().result());
842  return success();
843  }
844 };
845 
/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    // copy(fill(v)) -> fill(v) written directly into the copy's outputs.
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    // Copying over a filled destination: the fill is dead, copy straight into
    // the fill's own outputs.
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
866 
/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // transpose(fill(v)) == fill(v) into the transpose's init: a fill is
    // invariant under permutation of its dimensions.
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
882 
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    // NOTE: this predicate has a deliberate side effect — on success it also
    // appends the fill's init to `allOuts` so the destinations can be
    // concatenated below.
    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    // Concatenate the fill destinations, then fill the combined shape once.
    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
933 
934 } // namespace
935 
/// Register all fill-related folding patterns declared above.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
944 
945 //===----------------------------------------------------------------------===//
946 // GenericOp
947 //===----------------------------------------------------------------------===//
948 
949 static void buildGenericRegion(
950  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
951  ValueRange outputs,
952  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
953  SmallVector<Type, 4> blockArgTypes;
954  SmallVector<Location, 4> blockArgLocs;
955  for (ValueRange container : {inputs, outputs}) {
956  for (Value v : container) {
957  Type t = v.getType();
958  blockArgTypes.push_back(
959  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
960  blockArgLocs.push_back(v.getLoc());
961  }
962  }
963 
964  OpBuilder::InsertionGuard guard(builder);
965  Block *bodyBlock =
966  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
967  bodyBuild(builder, loc, bodyBlock->getArguments());
968 }
969 
970 void GenericOp::getAsmBlockArgumentNames(Region &region,
971  OpAsmSetValueNameFn setNameFn) {
972  for (Value v : getRegionInputArgs())
973  setNameFn(v, "in");
974  for (Value v : getRegionOutputArgs())
975  setNameFn(v, "out");
976 }
977 
/// Builder taking pre-built attributes; delegates to the generated builder,
/// attaches extra attributes, and (optionally) materializes the region body.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
991 
/// Convenience builder: converts AffineMaps, iterator-type enums, and plain
/// strings into attributes before delegating to the attribute-based builder.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        // Empty strings are elided rather than stored as empty attributes.
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}
1010 
/// Convenience builder without result types (pure buffer/no-result form).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1021 
/// Convenience builder without result types, doc, or library call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1032 
/// Convenience builder with result types but without doc or library call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1044 
/// Print the custom assembly for linalg.generic: a dictionary of core trait
/// attributes, the shared ins/outs clause, an optional `attrs = {...}` of
/// extra attributes, the region, and result types.
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      // Only core linalg trait attributes go into the leading dictionary.
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is implied by the ins/outs syntax; elide it below.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Emit `attrs = {...}` only if at least one non-core attribute remains.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1107 
/// Parse the custom assembly for linalg.generic (inverse of print above).
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1173 
/// Shared memory-effect computation for structured (linalg) ops: memref
/// inputs are reads; memref inits are writes, plus a read when the payload
/// actually uses the init value. Tensor operands carry no memory effects.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    // An init is also read when the payload consumes its current value.
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
1199 
/// Memory effects: defer to the shared structured-op implementation.
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1205 
/// Shared speculatability logic for structured (linalg) ops.
static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}
1215 
/// Speculatability: defer to the shared structured-op rules.
Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1219 
1220 LogicalResult GenericOp::verify() { return success(); }
1221 
1222 namespace {
1223 
/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      // Each yielded value must be a block argument of this op's own body.
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1300 
1301 } // namespace
1302 
/// The only generic-specific canonicalization: erase identity (copy-like)
/// generics that merely yield their inputs.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1307 
/// Folding delegates to memref::foldMemRefCast, which absorbs compatible
/// memref.cast producers into this op's operands (see that helper for the
/// exact conditions).
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
1311 
1312 //===----------------------------------------------------------------------===//
1313 // MapOp
1314 //===----------------------------------------------------------------------===//
1315 
1316 static ParseResult parseDstStyleOp(
1317  OpAsmParser &parser, OperationState &result,
1318  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1319  nullptr) {
1320  // Parse `ins` and `outs`.
1321  SmallVector<Type, 4> inputTypes, outputTypes;
1322  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1323  /*addOperandSegmentSizes=*/false))
1324  return failure();
1325 
1326  // Add result types.
1327  for (Type outputType : outputTypes) {
1328  if (llvm::isa<RankedTensorType>(outputType))
1329  result.addTypes(outputType);
1330  }
1331 
1332  // Parse required attributes.
1333  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1334  return failure();
1335 
1336  // Parse optional attributes.
1337  if (parser.parseOptionalAttrDict(result.attributes))
1338  return failure();
1339  return success();
1340 }
1341 
1342 void MapOp::getAsmBlockArgumentNames(Region &region,
1343  OpAsmSetValueNameFn setNameFn) {
1344  for (Value v : getRegionInputArgs())
1345  setNameFn(v, "in");
1346 }
1347 
1348 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1349  if (!getResults().empty())
1350  setNameFn(getResults().front(), "mapped");
1351 }
1352 
/// Builder for linalg.map: attaches extra attributes, promotes a tensor init
/// to a result type, and (optionally) materializes the mapper region.
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // The mapper only receives the scalarized inputs — no block argument is
  // created for the init.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1369 
1370 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1371  const OperationName &payloadOpName,
1372  const NamedAttrList &payloadOpAttrs,
1373  ArrayRef<Value> operands,
1374  bool initFirst = false) {
1375  OpBuilder b(parser.getContext());
1376  Region *body = result.addRegion();
1377  Block &block = body->emplaceBlock();
1378  b.setInsertionPointToStart(&block);
1379  SmallVector<Value> bbArgs;
1380  for (auto &operand : operands) {
1381  block.addArgument(
1382  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1383  b.getUnknownLoc());
1384  }
1385  SmallVector<Value> payloadOpOperands;
1386  // If initFirst flag is enabled, we consider init as the first position of
1387  // payload operands.
1388  if (initFirst) {
1389  payloadOpOperands.push_back(block.getArguments().back());
1390  for (const auto &arg : block.getArguments().drop_back())
1391  payloadOpOperands.push_back(arg);
1392  } else {
1393  payloadOpOperands = {block.getArguments().begin(),
1394  block.getArguments().end()};
1395  }
1396 
1397  Operation *payloadOp = b.create(
1398  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1399  payloadOpOperands,
1400  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1401  .getElementType()},
1402  payloadOpAttrs);
1403  b.create<YieldOp>(result.location, payloadOp->getResults());
1404 }
1405 
/// Parse linalg.map: an optional short-form `{ payload.op {attrs} }` prefix,
/// the shared ins/outs clause, then either a synthesized body (short form)
/// or an explicit mapper region (long form).
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the mapper around the payload op. The init
    // (last operand) does not get a payload operand, hence drop_back().
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    // Long form: explicit region with typed block arguments.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1442 
1443 // Retrieve the operation from the body, if it is the only one (except
1444 // yield) and if it gets the same amount of arguments as the body does.
1445 // If initFirst flag is enabled, we check that init takes the first position in
1446 // operands of payload.
1447 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1448  if (body->getOperations().size() != 2)
1449  return nullptr;
1450  Operation &payload = body->getOperations().front();
1451  assert(isa<YieldOp>(body->getOperations().back()));
1452 
1453  if (payload.getNumOperands() == 0 ||
1454  payload.getNumOperands() != body->getNumArguments())
1455  return nullptr;
1456  if (initFirst) {
1457  // check init
1458  if (payload.getOperands().back() != body->getArgument(0))
1459  return nullptr;
1460  // check rest
1461  for (const auto &[operand, bbArg] :
1462  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1463  if (bbArg != operand)
1464  return nullptr;
1465  }
1466  } else {
1467  for (const auto &[operand, bbArg] :
1468  llvm::zip(payload.getOperands(), body->getArguments())) {
1469  if (bbArg != operand)
1470  return nullptr;
1471  }
1472  }
1473  return &payload;
1474 }
1475 
/// Print the short-form payload header `{ payload.op {attrs} }`, eliding a
/// fastmath attribute when it carries the default (none) value.
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  // Backing storage for the elided attribute name; must outlive the
  // printOptionalAttrDict call since elidedAttrs holds a StringRef into it.
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}
1492 
/// Print linalg.map, preferring the short form `{ payload.op }` when the
/// mapper is a single op whose operands are exactly the block arguments.
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1516 
/// Verify linalg.map: arity of the mapper matches the inputs, mapper block
/// argument types match the inputs' element types, and every input shape
/// matches the init (output) shape.
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1551 
1552 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1553  int64_t rank = getInit().getType().getRank();
1554  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1555 }
1556 
/// Every operand (inputs and init) uses the identity indexing map.
ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}
1564 
/// Memory effects: defer to the shared structured-op implementation.
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1570 
/// Speculatability: defer to the shared structured-op rules.
Speculation::Speculatability MapOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1574 
1575 //===----------------------------------------------------------------------===//
1576 // ReduceOp
1577 //===----------------------------------------------------------------------===//
1578 
1579 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1580  OpAsmSetValueNameFn setNameFn) {
1581  for (Value v : getRegionInputArgs())
1582  setNameFn(v, "in");
1583  for (Value v : getRegionOutputArgs())
1584  setNameFn(v, "init");
1585 }
1586 
1587 void ReduceOp::getAsmResultNames(
1588  function_ref<void(Value, StringRef)> setNameFn) {
1589  if (!getResults().empty())
1590  setNameFn(getResults().front(), "reduced");
1591 }
1592 
/// Builder for linalg.reduce: attaches extra attributes, promotes tensor
/// inits to result types, and (optionally) materializes the combiner region.
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  // Combiner block arguments cover both inputs and inits (scalarized).
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
1612 
1613 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1614  int64_t inputRank =
1615  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1616  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1617  utils::IteratorType::parallel);
1618  for (int64_t reductionDim : getDimensions())
1619  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1620  return iteratorTypes;
1621 }
1622 
/// Inputs use the identity map over the input rank; each init uses the
/// identity map with the reduced dimensions dropped from its results.
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1636 
/// Memory effects: defer to the shared structured-op implementation.
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1642 
/// Speculatability: defer to the shared structured-op rules.
Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1646 
/// Parse `<attributeName> = <dense-i64-array>` and record it in `attributes`.
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}
1656 
/// Parse linalg.reduce: an optional short-form `{ payload.op {attrs} }`
/// prefix, the shared ins/outs clause plus the mandatory `dimensions`
/// attribute, then either a synthesized combiner (short form, init passed as
/// the payload's leading operand) or an explicit region (long form).
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: explicit region with typed block arguments.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1694 
/// Print ` <attributeName> = [<values>] ` — the counterpart of
/// parseDenseI64ArrayAttr.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1699 
/// Print linalg.reduce, preferring the short form `{ payload.op }` when the
/// combiner is a single op whose operands are the block arguments with the
/// init in the leading position.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` was printed explicitly above; elide it from the dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1723 
1724 LogicalResult ReduceOp::verify() {
1725  ArrayRef<int64_t> dimensionsRef = getDimensions();
1726 
1727  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1728  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1729  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1730  return emitOpError() << "expects all inputs to have the same shapes. "
1731  "Shape at input-index "
1732  << i
1733  << " is not equal to the shape at input-index 0.";
1734  }
1735  }
1736  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1737  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1738  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1739  return emitOpError() << "expects all outputs to have the same shapes. "
1740  "Shape at output-index "
1741  << i
1742  << " is not equal to the shape at output-index 0.";
1743  }
1744  }
1745  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1746  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1747 
1748  DenseSet<int64_t> dimensionsToReduce;
1749  for (int64_t dimension : dimensionsRef) {
1750  if (dimension < 0 || dimension >= inputType.getRank()) {
1751  return emitOpError()
1752  << "dimensions for reduction should be in the range [0, "
1753  << inputType.getRank() - 1 << "].";
1754  }
1755  dimensionsToReduce.insert(dimension);
1756  }
1757 
1758  auto inputDims = inputType.getShape();
1759  auto initDims = initType.getShape();
1760 
1761  // Input dimensions that will be left after the reduction.
1762  SmallVector<int64_t> reducedInputDims;
1763  for (const auto &en : llvm::enumerate(inputDims)) {
1764  if (!dimensionsToReduce.count(en.index()))
1765  reducedInputDims.push_back(en.value());
1766  }
1767 
1768  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1769  return emitOpError() << "number of dimensions after reduction "
1770  << reducedInputDims.size()
1771  << " doesn't match the init rank "
1772  << initType.getRank();
1773  }
1774 
1775  if (reducedInputDims != initDims)
1776  return emitOpError() << "init dimensions [" << initDims
1777  << "] doesn't match input dimensions after reduction ["
1778  << reducedInputDims << "]";
1779 
1780  Block *block = getBody();
1781  if (block->getNumArguments() != this->getNumOperands())
1782  return emitOpError()
1783  << "mismatching number of operands and block arguments";
1784 
1785  // Check that the first block arguments match the element type of the inputs.
1786  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1787  Type inputElementType =
1788  llvm::cast<ShapedType>(input.getType()).getElementType();
1789  if (inputElementType != bbArg.getType())
1790  return emitOpError()
1791  << "input element type " << inputElementType
1792  << " does not match corresponding block argument type "
1793  << bbArg.getType();
1794  }
1795 
1796  // Check that the last block arguments match the element type of the outputs.
1797  for (auto [output, bbArg] : llvm::zip(
1798  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1799  auto outputElementType =
1800  llvm::cast<ShapedType>(output.getType()).getElementType();
1801  if (outputElementType != bbArg.getType())
1802  return emitOpError()
1803  << "output element type " << outputElementType
1804  << " does not match corresponding block argument type "
1805  << bbArg.getType();
1806  }
1807  return success();
1808 }
1809 
1810 //===----------------------------------------------------------------------===//
1811 // TransposeOp
1812 //===----------------------------------------------------------------------===//
1813 
1814 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1815  Region &region, ValueRange inputs,
1816  ValueRange outputs) {
1817  buildGenericRegion(builder, loc, region, inputs, outputs,
1818  [](OpBuilder &b, Location loc, ValueRange args) {
1819  if (!args.empty())
1820  b.create<linalg::YieldOp>(loc, args[0]);
1821  });
1822 }
1823 
1824 void TransposeOp::build(::mlir::OpBuilder &builder,
1825  ::mlir::OperationState &result, Value input, Value init,
1826  DenseI64ArrayAttr permutation,
1827  ArrayRef<NamedAttribute> attributes) {
1828  result.addOperands(input);
1829  result.addOperands(init);
1830  result.addAttribute(getPermutationAttrName(result.name), permutation);
1831  result.addAttributes(attributes);
1832 
1833  // Add output types for `RankedTensorType` output arguments.
1834  Type initType = init.getType();
1835  if (llvm::isa<RankedTensorType>(initType))
1836  result.addTypes(initType);
1837 
1838  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1839  init);
1840 }
1841 
1842 void TransposeOp::build(::mlir::OpBuilder &builder,
1843  ::mlir::OperationState &result, Value input, Value init,
1844  ArrayRef<int64_t> permutation,
1845  ArrayRef<NamedAttribute> attributes) {
1846  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1847  attributes);
1848 }
1849 
1850 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1851  if (failed(parseDstStyleOp(
1852  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1853  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1854  })))
1855  return failure();
1856 
1857  OpBuilder builder(parser.getContext());
1858  buildIdentityRegion(builder, result.location, *result.addRegion(),
1859  /*inputs=*/result.operands,
1860  /*outputs=*/{});
1861  return success();
1862 }
1863 
1864 void TransposeOp::getAsmResultNames(
1865  function_ref<void(Value, StringRef)> setNameFn) {
1866  if (!getResults().empty())
1867  setNameFn(getResults().front(), "transposed");
1868 }
1869 
1871  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1872  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1873  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1874 }
1875 
1876 LogicalResult TransposeOp::verify() {
1877  ArrayRef<int64_t> permutationRef = getPermutation();
1878 
1879  if (!isPermutationVector(permutationRef))
1880  return emitOpError("permutation is not valid");
1881 
1882  auto inputType = getInput().getType();
1883  auto initType = getInit().getType();
1884 
1885  int64_t rank = inputType.getRank();
1886 
1887  if (rank != initType.getRank())
1888  return emitOpError() << "input rank " << rank
1889  << " does not match init rank " << initType.getRank();
1890 
1891  if (rank != static_cast<int64_t>(permutationRef.size()))
1892  return emitOpError() << "size of permutation " << permutationRef.size()
1893  << " does not match the argument rank " << rank;
1894 
1895  auto inputDims = inputType.getShape();
1896  auto initDims = initType.getShape();
1897 
1898  for (int64_t i = 0; i < rank; ++i) {
1899  int64_t inputDim = inputDims[permutationRef[i]];
1900  int64_t initDim = initDims[i];
1901 
1902  if (inputDim != initDim) {
1903  return emitOpError() << "dim(result, " << i << ") = " << initDim
1904  << " doesn't match dim(input, permutation[" << i
1905  << "]) = " << inputDim;
1906  }
1907  }
1908 
1909  return success();
1910 }
1911 
1912 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1913  int64_t rank = getInit().getType().getRank();
1914  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1915 }
1916 
1917 ArrayAttr TransposeOp::getIndexingMaps() {
1918  Builder builder(getContext());
1919  int64_t rank = getInit().getType().getRank();
1920  return builder.getAffineMapArrayAttr(
1922  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1923  builder.getMultiDimIdentityMap(rank)});
1924 }
1925 
1926 void TransposeOp::getEffects(
1928  &effects) {
1929  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1930 }
1931 
1932 Speculation::Speculatability TransposeOp::getSpeculatability() {
1933  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1934 }
1935 
1936 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1938  // Only the tensor type is supported.
1939  if (!isa<TensorType>(getInput().getType()))
1940  return failure();
1941 
1942  // Single dimension transpose.
1943  if (getPermutation().size() == 0) {
1944  result.push_back(getInput());
1945  return success();
1946  }
1947  // Identity permutation.
1948  if (isIdentityPermutation(getPermutation())) {
1949  result.push_back(getInput());
1950  return success();
1951  }
1952 
1953  return failure();
1954 }
1955 
1956 /// Fold transpose with transpose.
1957 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1959 
1960  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1961  PatternRewriter &rewriter) const override {
1962  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1963  if (!defTransposeOp)
1964  return failure();
1965  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1966  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1967  SmallVector<int64_t> foldedPerms;
1968  foldedPerms.reserve(perms.size());
1969  for (int64_t perm : perms)
1970  foldedPerms.push_back(defPerms[perm]);
1971 
1972  rewriter.replaceOpWithNewOp<TransposeOp>(
1973  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1974  foldedPerms);
1975  return success();
1976  }
1977 };
1978 
1979 /// This pattern canonicalize transpose by swapping the order of
1980 /// broadcast and transpose:
1981 /// transpose(broadcast(input)) -> broadcast(transpose(input))
1982 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
1984 
1985  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1986  PatternRewriter &rewriter) const override {
1987  Value input = transposeOp.getInput();
1988  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
1989  if (!input.hasOneUse() || !broadcastOp)
1990  return failure();
1991 
1992  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
1993  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1994 
1995  // Get new perms and new dimensions.
1996  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
1997  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
1998  SmallVector<int64_t> resultDimensions;
1999  unsigned dimensionSize = dimensions.size();
2000  for (unsigned i = 0; i < dimensionSize; ++i)
2001  resultDimensions.push_back(invertPerm[dimensions[i]]);
2002 
2003  // Create transpose result.
2004  Value broadcastInput = broadcastOp.getInput();
2005  Location loc = transposeOp.getLoc();
2006  MLIRContext *ctx = transposeOp.getContext();
2008  auto broadcastInputTy =
2009  mlir::cast<RankedTensorType>(broadcastInput.getType());
2010  unsigned inputRank = broadcastInputTy.getRank();
2011  for (unsigned i = 0; i < inputRank; ++i) {
2012  if (broadcastInputTy.isDynamicDim(i)) {
2013  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2014  ->getResult(0));
2015  } else {
2016  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2017  broadcastInputTy.getDimSize(i)));
2018  }
2019  }
2020  SmallVector<OpFoldResult> transposeResultShapes =
2021  applyPermutation(dims, resultPerms);
2022  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2023  transposeOp.getLoc(), transposeResultShapes,
2024  broadcastInputTy.getElementType());
2025 
2026  // Create broadcast(transpose(input)).
2027  Value transposeResult =
2028  rewriter
2029  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2030  resultPerms)
2031  ->getResult(0);
2032  rewriter.replaceOpWithNewOp<BroadcastOp>(
2033  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2034  return success();
2035  }
2036 };
2037 
2038 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2039  MLIRContext *context) {
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // BroadcastOp
2045 //===----------------------------------------------------------------------===//
2046 
2047 void BroadcastOp::build(::mlir::OpBuilder &builder,
2048  ::mlir::OperationState &result, Value input, Value init,
2049  DenseI64ArrayAttr dimensions,
2050  ArrayRef<NamedAttribute> attributes) {
2051  result.addOperands(input);
2052  result.addOperands(init);
2053  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2054  result.addAttributes(attributes);
2055 
2056  // Add output types for `RankedTensorType` output arguments.
2057  Type initType = init.getType();
2058  if (llvm::isa<RankedTensorType>(initType))
2059  result.addTypes(initType);
2060 
2061  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2062  init);
2063 }
2064 
2065 void BroadcastOp::build(::mlir::OpBuilder &builder,
2066  ::mlir::OperationState &result, Value input, Value init,
2067  ArrayRef<int64_t> dimensions,
2068  ArrayRef<NamedAttribute> attributes) {
2069  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2070  attributes);
2071 }
2072 
2073 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2074  if (failed(parseDstStyleOp(
2075  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2076  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2077  })))
2078  return failure();
2079 
2080  OpBuilder builder(parser.getContext());
2081  buildIdentityRegion(builder, result.location, *result.addRegion(),
2082  /*inputs=*/result.operands,
2083  /*outputs=*/{});
2084  return success();
2085 }
2086 
2087 void BroadcastOp::getAsmResultNames(
2088  function_ref<void(Value, StringRef)> setNameFn) {
2089  if (!getResults().empty())
2090  setNameFn(getResults().front(), "broadcasted");
2091 }
2092 
2094  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2095  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2096  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2097 }
2098 
2099 LogicalResult BroadcastOp::verify() {
2100  ArrayRef<int64_t> dimensionsRef = getDimensions();
2101 
2102  auto inputType = getInput().getType();
2103  auto initType = getInit().getType();
2104 
2105  int64_t inputRank = inputType.getRank();
2106  int64_t initRank = initType.getRank();
2107 
2108  auto inputShape = inputType.getShape();
2109  auto initShape = initType.getShape();
2110 
2111  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2112  return emitOpError() << "input rank plus added dimensions does not "
2113  "match init rank. input rank: "
2114  << inputRank
2115  << ", dimensions size: " << dimensionsRef.size()
2116  << ", init rank: " << initRank;
2117 
2118  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2119  if (dim < 0 || dim >= initRank)
2120  return emitOpError() << "dimension " << idx
2121  << " is out of range. expected range: [0, "
2122  << initRank - 1 << "], got: " << dim;
2123  }
2124 
2125  // Mapping from input dims to init dims.
2126  SmallVector<int64_t> dimMap;
2127  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2128  if (!llvm::is_contained(dimensionsRef, dim))
2129  dimMap.push_back(dim);
2130  }
2131 
2132  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2133  // This dimensions is mapped from the input. Init and input dims should
2134  // match.
2135  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2136  return emitOpError() << "input dim " << inputDimIdx
2137  << " should match init dim " << initDimIdx
2138  << ". input: " << inputShape[inputDimIdx]
2139  << ", init: " << initShape[initDimIdx];
2140  }
2141 
2142  return success();
2143 }
2144 
2145 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2146  int64_t rank = getInit().getType().getRank();
2147  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2148 }
2149 
2150 ArrayAttr BroadcastOp::getIndexingMaps() {
2151  Builder builder(getContext());
2152  int64_t rank = getInit().getType().getRank();
2153  return builder.getAffineMapArrayAttr(
2154  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2155  builder.getMultiDimIdentityMap(rank)});
2156 }
2157 
2158 void BroadcastOp::getEffects(
2160  &effects) {
2161  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2162 }
2163 
2164 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2165  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2166 }
2167 
2168 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2169  MLIRContext *context) {
2170  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // YieldOp
2175 //===----------------------------------------------------------------------===//
2176 
2178  if (getNumOperands() > 0)
2179  p << ' ' << getOperands();
2180  p.printOptionalAttrDict((*this)->getAttrs());
2181  if (getNumOperands() > 0)
2182  p << " : " << getOperandTypes();
2183 }
2184 
2185 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2187  SmallVector<Type, 2> types;
2188  SMLoc loc = parser.getCurrentLocation();
2189  return failure(parser.parseOperandList(opInfo) ||
2190  parser.parseOptionalAttrDict(result.attributes) ||
2191  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2192  parser.resolveOperands(opInfo, types, loc, result.operands));
2193 }
2194 
2195 // Check the operand number and types must match the element types of the
2196 // LinalgOp interface's shaped operands.
2197 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2198  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2199  return op.emitOpError("expected number of yield values (")
2200  << op.getNumOperands()
2201  << ") to match the number of inits / outs operands of the enclosing "
2202  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2203 
2204  for (OpOperand &opOperand : op->getOpOperands()) {
2205  OpOperand *outputOperand =
2206  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2207  Type elementType = outputOperand->get().getType();
2208  if (isa<MemRefType, RankedTensorType>(elementType))
2209  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2210  if (opOperand.get().getType() != elementType)
2211  return op.emitOpError("type of yield operand ")
2212  << (opOperand.getOperandNumber() + 1) << " ("
2213  << opOperand.get().getType() << ") doesn't match "
2214  << "the element type of the enclosing linalg.generic op ("
2215  << elementType << ")";
2216  }
2217  return success();
2218 }
2219 
2220 LogicalResult linalg::YieldOp::verify() {
2221  auto *parentOp = (*this)->getParentOp();
2222  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2223  return emitOpError("expected single non-empty parent region");
2224 
2225  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2226  return verifyYield(*this, linalgOp);
2227 
2228  return emitOpError("expected parent op with LinalgOp interface");
2229 }
2230 
2231 //===----------------------------------------------------------------------===//
2232 // IndexOp
2233 //===----------------------------------------------------------------------===//
2234 
2235 LogicalResult IndexOp::verify() {
2236  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2237  if (!linalgOp)
2238  return emitOpError("expected parent op with LinalgOp interface");
2239  if (linalgOp.getNumLoops() <= getDim())
2240  return emitOpError("expected dim (")
2241  << getDim() << ") to be lower than the number of loops ("
2242  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2243  return success();
2244 }
2245 
2246 /////// Operations corresponding to library calls defined with Tablegen ////////
2247 
2248 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2249 
2250 #define GET_OP_CLASSES
2251 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2252 
2253 #define GET_OP_CLASSES
2254 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2255 
2256 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2257  unsigned rank,
2258  MLIRContext *context) {
2259  if (maybeMap)
2260  return *maybeMap;
2261  if (rank == 0)
2262  return AffineMap::get(context);
2263  return AffineMap::getMultiDimIdentityMap(rank, context);
2264 }
2265 
2267 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2268  MLIRContext *context) {
2270  res.reserve(num);
2271  for (unsigned i = 0; i < num; ++i)
2272  res.push_back(getAffineDimExpr(startIdx++, context));
2273  return res;
2274 }
2275 
2278  auto rangeA = llvm::make_range(a.begin(), a.end());
2279  auto rangeB = llvm::make_range(b.begin(), b.end());
2280  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2281  return llvm::to_vector<4>(concatRanges);
2282 }
2283 
2284 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2285  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2286  ss << "view";
2287  for (auto size : memref.getShape())
2288  if (size < 0)
2289  ss << "sx";
2290  else
2291  ss << size << "x";
2292  if (failed(appendMangledType(ss, memref.getElementType())))
2293  return failure();
2294  if (auto as = memref.getMemorySpace()) {
2295  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2296  ss << "as" << attr.getInt();
2297  else
2298  return failure();
2299  }
2300  return success();
2301  }
2302  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2303  ss << "vector";
2304  llvm::interleave(
2305  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2306  if (failed(appendMangledType(ss, vec.getElementType())))
2307  return failure();
2308  return success();
2309  }
2310  if (t.isSignlessIntOrIndexOrFloat()) {
2311  ss << t;
2312  return success();
2313  }
2314  return failure();
2315 }
2316 
2318  assert(isa<LinalgOp>(op));
2319  std::string name(op->getName().getStringRef().str());
2320  std::string fun = "";
2321  for (NamedAttribute kv : op->getAttrs()) {
2322  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2323  fun = stringifyEnum(ufa.getValue()).str() + "_";
2324  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2325  fun = stringifyEnum(bfa.getValue()).str() + "_";
2326  }
2327  }
2328  name.reserve(128);
2329  std::replace(name.begin(), name.end(), '.', '_');
2330  llvm::raw_string_ostream ss(name);
2331  ss << "_" << fun;
2332  for (Type t : op->getOperandTypes()) {
2333  if (failed(appendMangledType(ss, t)))
2334  return std::string();
2335  ss << "_";
2336  }
2337  name.pop_back();
2338  return name;
2339 }
2340 
2341 //===----------------------------------------------------------------------===//
2342 // Canonicalizers and Folders.
2343 //===----------------------------------------------------------------------===//
2344 
2345 namespace {
2346 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2348 
2349  LogicalResult matchAndRewrite(LinalgOp op,
2350  PatternRewriter &rewriter) const override {
2351  for (OpOperand &opOperand : op->getOpOperands()) {
2352  // Linalg "inputs" may be either tensor or memref type.
2353  // tensor<0xelt_type> is a convention that may not always mean
2354  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2355  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2356  if (!mt)
2357  continue;
2358  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2359  rewriter.eraseOp(op);
2360  return success();
2361  }
2362  }
2363  return failure();
2364  }
2365 };
2366 
2367 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2368 /// result that is more static than the linalg op.
2369 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2371 
2372  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2373  PatternRewriter &rewriter) const override {
2374  if (!tensor::canFoldIntoProducerOp(castOp))
2375  return failure();
2376 
2377  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2378  if (!linalgOp)
2379  return failure();
2380 
2381  // Cast can be in conditionally reachable region, if which case folding will
2382  // generate invalid code. Only conservatively fold ops in same block for
2383  // now.
2384  if (castOp->getBlock() != linalgOp->getBlock())
2385  return failure();
2386 
2387  OpBuilder::InsertionGuard guard(rewriter);
2388  rewriter.setInsertionPoint(linalgOp);
2389 
2390  Location loc = linalgOp.getLoc();
2391  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2392  unsigned resultNumber = resultValue.getResultNumber();
2393  auto resultType =
2394  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2395  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2396  // going from a more dynamic shape to a less dynamic shape. If the producer
2397  // for this cast, i.e. producer of the out operand, is also an operation
2398  // that folds with tensor.cast consumer (like this pattern), the cast will
2399  // continue to propagate as far up the stack as it can go.
2400  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2401  Value newOperand =
2402  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2403  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2404  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2405  linalgOp.getDpsInits().end());
2406  outputOperands[resultNumber] = newOperand;
2407  newOperands.append(outputOperands.begin(), outputOperands.end());
2408 
2409  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2410  linalgOp->result_type_end());
2411  resultTypes[resultNumber] = resultType;
2412  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2413 
2414  // Create a tensor.cast operation back to the original type.
2415  Value castBack = rewriter.create<tensor::CastOp>(
2416  loc, resultValue.getType(), newOp->getResult(resultNumber));
2417 
2418  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2419  results[resultNumber] = castBack;
2420  rewriter.replaceOp(linalgOp, results);
2421  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2422  return success();
2423  }
2424 };
2425 
2426 /// For each of the operand in `operands` this function maps the static sizes of
2427 /// dimensions to their affine dim expressions.
2428 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2429  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2430  for (OpOperand &opOperand : operands) {
2431  if (linalgOp.isScalar(&opOperand))
2432  continue;
2433  Value src = opOperand.get();
2434  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2435  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2436 
2437  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2438  // `tensor.cast` operation and source of the cast operation has a static
2439  // shape, then assign it to the `sourceShape`.
2440  auto *parentOp = src.getDefiningOp();
2441  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2442  if (parentOp) {
2443  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2444  Value castSource = castOp.getSource();
2445  auto castSourceType =
2446  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2447  if (castSourceType && castSourceType.hasStaticShape())
2448  sourceShape = castSourceType.getShape();
2449  }
2450  }
2451 
2452  // If the source shape's dimension has a static shape, map the affine dim
2453  // expression to the known static size.
2454  for (unsigned i = 0; i < sourceShape.size(); i++) {
2455  if (sourceType.isDynamicDim(i))
2456  continue;
2457  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2458  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2459  }
2460  }
2461 }
2462 
2463 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2464 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2465 /// their result types is stored in `resultTypes`. If `opOperand` requires no
2466 /// change then `changeNeeded` is false and same operand is added in the
2467 /// `newOperands` list.
2468 static void createNewOperandWithStaticSizes(
2469  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2470  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2471  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2472  bool &changeNeeded) {
2473  Value src = opOperand->get();
2474  newOperands.push_back(src);
2475  if (linalgOp.isScalar(opOperand))
2476  return;
2477  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2478  Type resultType = sourceType;
2479  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2480  resultTypes.push_back(resultType);
2481  return;
2482  }
2483  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2484  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2485  SmallVector<int64_t> newShape;
2486  // If operand is updated with new shape, `newOperandNeeded` will be
2487  // true.
2488  bool newOperandNeeded = false;
2489  for (unsigned i = 0; i < sourceShape.size(); i++) {
2490  int64_t dimShape = sourceShape[i];
2491  AffineExpr dimExpr = sourceMap.getResult(i);
2492  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2493  newShape.push_back(dimShape);
2494  continue;
2495  }
2496  // Dimension has a dynamic shape and corresponding affine dim
2497  // expression is present in the map. So assign the size for the
2498  // given affine dim expression to the dimension.
2499  newShape.push_back(affineExprToSize[dimExpr]);
2500  newOperandNeeded = true;
2501  }
2502  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2503  if (newOperandNeeded) {
2504  changeNeeded = true;
2505  // Get the new operand value given its size and element type by
2506  // casting it.
2507  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2508  unsigned index = opOperand->getOperandNumber();
2509  newOperands[index] = newOperand;
2510  }
2511  if (linalgOp.isDpsInit(opOperand))
2512  resultTypes.push_back(resultType);
2513 }
2514 
2515 /// Static shapes for the operands can be inferred if any one of the operands
2516 /// have a static shape. This can be done by referring to the affine dim
2517 /// expressions for the operand.
2518 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2520 
2521  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2522  PatternRewriter &rewriter) const override {
2523  if (!linalgOp.hasPureTensorSemantics())
2524  return failure();
2525 
2526  // Maps must be projected permutations.
2527  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2528  return !map.isProjectedPermutation();
2529  }))
2530  return failure();
2531 
2532  // Maps affine dim expressions to the static size of that dimension.
2533  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2534  Location loc = linalgOp.getLoc();
2535 
2536  // For each of the affine dim expression, check if the size is known. If
2537  // known add that in the map.
2538  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2539 
2540  SmallVector<Value> newOperands;
2541  SmallVector<Type> resultTypes;
2542 
2543  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2544  // change in their types.
2545  bool changeNeeded = false;
2546  newOperands.reserve(linalgOp->getNumOperands());
2547  resultTypes.reserve(linalgOp.getNumDpsInits());
2548 
2549  // Iterate over all the operands and update the static sizes.
2550  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2551  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2552  affineExprToSize, linalgOp, newOperands,
2553  resultTypes, changeNeeded);
2554  }
2555 
2556  // If the generic op has all the required static information, no
2557  // canonicalization needed.
2558  if (!changeNeeded)
2559  return failure();
2560 
2561  // Clone op.
2562  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2563  SmallVector<Value> replacements;
2564  replacements.reserve(newOp->getNumResults());
2565  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2566  Value newResult = std::get<1>(it);
2567  Value oldResult = std::get<0>(it);
2568  Type newType = newResult.getType();
2569  Type oldType = oldResult.getType();
2570  replacements.push_back(
2571  (newType != oldType)
2572  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2573  : newResult);
2574  }
2575  rewriter.replaceOp(linalgOp, replacements);
2576  return success();
2577  }
2578 };
2579 
2580 } // namespace
2581 
2582 // All named ops canonicalizers and folders are auto-generated in the
2583 // .cpp.inc.
2584 
2585 //===----------------------------------------------------------------------===//
2586 // SoftmaxOp
2587 //===----------------------------------------------------------------------===//
2588 
2589 LogicalResult SoftmaxOp::verify() {
2590  ShapedType inputType = getInputOperandType();
2591  ShapedType outputType = getOutputOperandType();
2592 
2593  ArrayRef<int64_t> inputShape = inputType.getShape();
2594  ArrayRef<int64_t> outputShape = outputType.getShape();
2595  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2596  return emitOpError("incompatible output shape");
2597 
2598  int64_t inputRank = getInputOperandRank();
2599  int64_t dimension = getDimension();
2600  if ((dimension < 0) || (dimension >= inputRank))
2601  return emitOpError("incorrect dimension specified");
2602 
2603  return success();
2604 }
2605 
2606 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2607  int64_t operandRank = getInputOperandRank();
2608  SmallVector<Range> loopBounds(operandRank);
2609  Location loc = getLoc();
2610  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2611  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2612  Value source = getInput();
2613  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2614  loopBounds[dim].offset = zero;
2615  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2616  loopBounds[dim].stride = one;
2617  }
2618  return loopBounds;
2619 }
2620 
2621 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2622  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2623  utils::IteratorType::parallel);
2624  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2625  return iteratorTypes;
2626 }
2627 
2628 FailureOr<TilingResult>
2629 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2630  ArrayRef<OpFoldResult> offsets,
2631  ArrayRef<OpFoldResult> sizes) {
2632  int64_t rank = getInputOperandRank();
2633  auto oneAttr = builder.getI64IntegerAttr(1);
2634  SmallVector<OpFoldResult> strides(rank, oneAttr);
2635  SmallVector<Value> tiledOperands;
2636  Operation *inputSlice =
2637  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2638  if (!inputSlice) {
2639  return emitOpError("failed to compute input slice");
2640  }
2641  tiledOperands.emplace_back(inputSlice->getResult(0));
2642  Operation *outputSlice =
2643  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2644  if (!outputSlice) {
2645  return emitOpError("failed to compute output slice");
2646  }
2647  tiledOperands.emplace_back(outputSlice->getResult(0));
2648 
2649  SmallVector<Type, 4> resultTypes;
2650  if (hasPureTensorSemantics())
2651  resultTypes.push_back(tiledOperands[1].getType());
2652  Operation *tiledOp =
2653  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2654 
2655  return TilingResult{
2656  {tiledOp},
2657  SmallVector<Value>(tiledOp->getResults()),
2658  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2659 }
2660 
2661 LogicalResult SoftmaxOp::getResultTilePosition(
2662  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2663  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2664  SmallVector<OpFoldResult> &resultSizes) {
2665  if (resultNumber == 0) {
2666  resultOffsets.assign(offsets.begin(), offsets.end());
2667  resultSizes.assign(sizes.begin(), sizes.end());
2668  return success();
2669  }
2670  return failure();
2671 }
2672 
2673 // cast(dynamic) -> static.
2674 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2675  return memref::foldMemRefCast(*this);
2676 }
2677 
2678 LogicalResult
2680  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2682  Location loc = getOperation()->getLoc();
2683  IRRewriter rewriter(b);
2684  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2685  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2686  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2687  if (!outputShapedType.isDynamicDim(dim)) {
2688  // Static dim: Return IntegerAttr.
2689  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2690  } else {
2691  // Dynamic dim: Return Value.
2692  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2693  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2694  }
2695  }
2696  reifiedReturnShapes.emplace_back(std::move(shapes));
2697  return success();
2698 }
2699 
2700 void SoftmaxOp::getEffects(
2702  &effects) {
2703  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2704  if (!llvm::isa<MemRefType>(operand.getType()))
2705  continue;
2706  effects.emplace_back(MemoryEffects::Read::get(),
2707  &getOperation()->getOpOperand(index), /*stage=*/0,
2708  /*effectOnFullRegion=*/true,
2710  }
2711 
2712  for (OpOperand &operand : getDpsInitsMutable()) {
2713  if (!llvm::isa<MemRefType>(operand.get().getType()))
2714  continue;
2715  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2716  /*effectOnFullRegion=*/true,
2718  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2719  /*effectOnFullRegion=*/true,
2721  }
2722 }
2723 
2724 // Helper functions for softmax decomposition.
2725 // @{
2726 
2727 // Helper function to produce the iterator types (reduction or parallel) and
2728 // affine maps for the iterators used in the decomposition of softmax.
2729 // This method creates:
2730 // If allParallel == true:
2731 // - iterator type: {parallel, ..., parallel}
2732 // - affine maps:
2733 // -- identity with inputRank dimensions.
2734 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2735 // where N == inputRank.
2736 //
2737 // If allParallel == false:
2738 // - iterator type at dim(i) == parallel for i != \p dim and
2739 // dim(dim) == reduction.
2740 // - affine map:
2741 // -- identity with inputRank dimensions.
2742 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2743 // where N == inputRank.
2744 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2746  int64_t dim, bool allParallel = false) {
2747  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2748  utils::IteratorType::parallel);
2749  if (!allParallel)
2750  iteratorTypes[dim] = utils::IteratorType::reduction;
2751  MLIRContext *ctxt = builder.getContext();
2752  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2753  SmallVector<AffineExpr, 2> affineExprs;
2754  for (int i = 0; i < inputRank; i++) {
2755  if (i != dim)
2756  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2757  }
2758  auto reductionMap =
2759  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2760  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2761  return std::make_tuple(iteratorTypes, indexingMaps);
2762 }
2763 
2764 // Helper function to produce a linalg.generic that computes a reduction on
2765 // dimension \p dim with the operation type \p T.
2766 template <typename T>
2767 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2768  int64_t dim) {
2769  auto inputType = cast<ShapedType>(input.getType());
2770  ArrayRef<int64_t> inputShape = inputType.getShape();
2771  int64_t inputRank = inputShape.size();
2772  auto [iteratorTypes, indexingMaps] =
2773  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2774  assert(indexingMaps.size() == 2 &&
2775  "We should have two maps: 1 for the input, 1 for the output");
2776  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2777 
2778  auto genericOp = builder.create<linalg::GenericOp>(
2779  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2780  [&](OpBuilder &b, Location loc, ValueRange args) {
2781  Value result = b.create<T>(loc, args[0], args[1]);
2782  b.create<linalg::YieldOp>(loc, result);
2783  });
2784  return genericOp.getResult(0);
2785 }
2786 
2787 /// Produce a linalg generic that computes the second step of the softmax
2788 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2789 /// on dimension \p dim.
2790 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2791  Value max, Value output, int64_t dim) {
2792  auto inputType = cast<ShapedType>(input.getType());
2793  ArrayRef<int64_t> inputShape = inputType.getShape();
2794  int64_t inputRank = inputShape.size();
2795  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2796  builder, inputRank, dim, /*allParallel=*/true);
2797  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2798  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2799  // Add the affine map for the output argument.
2800  indexingMaps.push_back(indexingMaps[0]);
2801  auto genericOp = builder.create<linalg::GenericOp>(
2802  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2803  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2804  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2805  Value result = b.create<math::ExpOp>(loc, diff);
2806  b.create<linalg::YieldOp>(loc, result);
2807  });
2808  return genericOp.getResult(0);
2809 }
2810 
2811 /// Produce a linalg generic that computes the final step of the softmax
2812 /// decomposition.
2813 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2814 /// yield n / d
2815 /// }
2816 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2817  Value denominator, Value output, int64_t dim) {
2818  auto inputType = cast<ShapedType>(numerator.getType());
2819  ArrayRef<int64_t> inputShape = inputType.getShape();
2820  int64_t inputRank = inputShape.size();
2821  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2822  builder, inputRank, dim, /*allParallel=*/true);
2823  assert(indexingMaps.size() == 2 &&
2824  "We should have one map for each input (2)");
2825  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2826  // Add the affine map for the output tensor.
2827  indexingMaps.push_back(indexingMaps[0]);
2828  auto genericOp = builder.create<linalg::GenericOp>(
2829  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2830  indexingMaps, iteratorTypes,
2831  [&](OpBuilder &b, Location loc, ValueRange args) {
2832  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2833  b.create<linalg::YieldOp>(loc, result);
2834  });
2835  return genericOp.getResult(0);
2836 }
2837 // @} End helper functions for softmax decomposition.
2838 
2839 /// Given an N-dimensional tensor x, this method converts
2840 /// softmax(x) to the following sequence of operations:
2841 ///
2842 /// 1. Compute the max of x along dimension d. This results
2843 /// in a N-1 dimensional tensor m.
2844 /// m = max(x, dim = d)
2845 ///
2846 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2847 /// a N dimensional tensor z.
2848 /// z = exp(x - m)
2849 ///
2850 /// 3. Compute the sum of z along dimension d. This results in
2851 /// a N-1 dimensional tensor l.
2852 /// l = sum(z, dim = d)
2853 ///
2854 /// 4. Divide z and l. This gives the N-dimensional softmax.
2855 /// softmax = z / l
2856 ///
2857 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2858  OpBuilder::InsertionGuard guard(b);
2859  b.setInsertionPoint(*this);
2860  Location loc = getLoc();
2861  Value input = getInput();
2862  ShapedType inputType = getInputOperandType();
2863  Type elementType = inputType.getElementType();
2864  int64_t reductionDim = getDimension();
2865  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2866  Value output = getOutput();
2867  dims.erase(dims.begin() + reductionDim);
2868  // Step 1: Compute max along dim.
2869  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2870  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2871  elementType, b, loc,
2872  /*useOnlyFiniteValue=*/true);
2873  Value neutralForMaxFInit =
2874  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2875  .result();
2876  Value max =
2877  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2878 
2879  // Step 2: Subtract max from input and exponentiate.
2880  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2881 
2882  // Step 3: Compute sum along dim.
2883  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2884  b, loc, /*useOnlyFiniteValue=*/true);
2885  Value zeroInit =
2886  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2887  Value denominator =
2888  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2889 
2890  // Step 4: Compute softmax.
2891  Value result =
2892  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2893  return SmallVector<Value>{result};
2894 }
2895 
2896 //===----------------------------------------------------------------------===//
2897 // WinogradFilterTransformOp
2898 //===----------------------------------------------------------------------===//
2899 
2900 LogicalResult WinogradFilterTransformOp::verify() {
2901  auto filterType = cast<ShapedType>(getFilter().getType());
2902  ArrayRef<int64_t> filterShape = filterType.getShape();
2903  int64_t filterH = filterShape[getFilterHDim()];
2904  int64_t filterW = filterShape[getFilterWDim()];
2905  int64_t r = getR();
2906  int64_t m = getM();
2907 
2908  if (filterH != r && filterH != 1)
2909  return emitOpError("expect filter height either equals to r or 1");
2910  if (filterW != r && filterW != 1)
2911  return emitOpError("expect filter width either equals to r or 1");
2912  if (filterH == 1 && filterW == 1)
2913  return emitOpError("expect either filter height or width equals to r");
2914 
2915  SmallVector<int64_t> expectedOutputShape;
2916  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2917  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2918  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2919  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2920 
2921  auto outputType = cast<ShapedType>(getOutput().getType());
2922  ArrayRef<int64_t> outputShape = outputType.getShape();
2923  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2924  return emitOpError("the output shape is not expected");
2925  }
2926  return success();
2927 }
2928 
2930 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2931  Location loc = getLoc();
2932  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2933  IntegerAttr oneAttr = builder.getIndexAttr(1);
2934  Value filter = getFilter();
2935  int64_t filterRank = getFilterOperandRank();
2936  SmallVector<Range> loopBounds(filterRank);
2937  for (unsigned dim = 0; dim < filterRank; ++dim) {
2938  loopBounds[dim].offset = zeroAttr;
2939  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2940  loopBounds[dim].stride = oneAttr;
2941  }
2942  return loopBounds;
2943 }
2944 
2946 WinogradFilterTransformOp::getLoopIteratorTypes() {
2947  int64_t filterRank = getFilterOperandRank();
2948  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2949  utils::IteratorType::parallel);
2950  return iteratorTypes;
2951 }
2952 
2953 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2954  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2955  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2956  SmallVector<OpFoldResult> &resultSizes) {
2957  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2958  ShapedType filterType = getFilterOperandType();
2959  ArrayRef<int64_t> filterShape = filterType.getShape();
2960  int64_t filterH = filterShape[getFilterHDim()];
2961  int64_t filterW = filterShape[getFilterWDim()];
2962  int64_t m = getM();
2963  int64_t r = getR();
2964  int64_t alpha = m + r - 1;
2965  int64_t alphaH = filterH != 1 ? alpha : 1;
2966  int64_t alphaW = filterW != 1 ? alpha : 1;
2967  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2968  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
2969 
2970  resultOffsets.append(
2971  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2972  resultSizes.append(
2973  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
2974 
2975  return success();
2976 }
2977 
2978 /// Implement tiling for winograd_filter_transform
2979 /// The input of winograd_filter_transform is (F, KH, KW, C).
2980 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
2981 /// Users can specify the tile sizes of F and C.
2982 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
2983 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
2984 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
2985  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
2986  ArrayRef<OpFoldResult> sizes) {
2987  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
2988  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2989  ShapedType filterType = getFilterOperandType();
2990  ArrayRef<int64_t> filterShape = filterType.getShape();
2991  int64_t filterH = filterShape[getFilterHDim()];
2992  int64_t filterW = filterShape[getFilterWDim()];
2993  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
2994  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
2995  SmallVector<Value> tiledOperands;
2996  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2997 
2998  sliceOffsets.append(
2999  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3000  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3001  sizes[getFilterCDim()]});
3002  int64_t filterRank = getFilterOperandRank();
3003  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3004  Location loc = getLoc();
3005  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3006  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3007  tiledOperands.emplace_back(filterSlice);
3008 
3009  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3010  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3011  resultSizes)))
3012  return failure();
3013 
3014  int64_t outputRank = getOutputOperandRank();
3015  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3016  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3017  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3018  tiledOperands.emplace_back(outputSlice);
3019 
3020  SmallVector<Type> resultTypes;
3021  resultTypes.push_back(tiledOperands[1].getType());
3022  Operation *tiledOp =
3023  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3024 
3025  return TilingResult{
3026  {tiledOp},
3027  SmallVector<Value>(tiledOp->getResults()),
3028  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3029 }
3030 
3031 //===----------------------------------------------------------------------===//
3032 // WinogradInputTransformOp
3033 //===----------------------------------------------------------------------===//
3034 
3035 LogicalResult WinogradInputTransformOp::verify() {
3036  auto inputType = cast<ShapedType>(getInput().getType());
3037  ArrayRef<int64_t> inputShape = inputType.getShape();
3038  int64_t inputH = inputShape[getInputHDim()];
3039  int64_t inputW = inputShape[getInputWDim()];
3040  int m = getM();
3041  int r = getR();
3042  int64_t tileSize = m + r - 1;
3043  bool leftTransform = inputH != 1;
3044  bool rightTransform = inputW != 1;
3045 
3046  SmallVector<int64_t> expectedOutputShape(6, inputH);
3047  if (ShapedType::isDynamic(inputH)) {
3048  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3049  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3050  } else {
3051  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3052  expectedOutputShape[getOutputTileHDim()] =
3053  leftTransform ? (inputH - (r - 1)) / m : 1;
3054  }
3055  if (ShapedType::isDynamic(inputW)) {
3056  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3057  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3058  } else {
3059  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3060  expectedOutputShape[getOutputTileWDim()] =
3061  rightTransform ? (inputW - (r - 1)) / m : 1;
3062  }
3063  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3064  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3065 
3066  auto outputType = cast<ShapedType>(getOutput().getType());
3067  ArrayRef<int64_t> outputShape = outputType.getShape();
3068  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3069  return emitOpError("the output shape is not expected");
3070  }
3071  return success();
3072 }
3073 
3075 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3076  Location loc = getLoc();
3077  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3078  IntegerAttr oneAttr = builder.getIndexAttr(1);
3079  Value output = getOutput();
3080  int64_t outputRank = getOutputOperandRank();
3081  SmallVector<Range> loopBounds(outputRank);
3082  for (unsigned dim = 0; dim < outputRank; ++dim) {
3083  loopBounds[dim].offset = zeroAttr;
3084  // alphaH, alphaW, tileH, tileW, N, C
3085  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3086  loopBounds[dim].stride = oneAttr;
3087  }
3088  return loopBounds;
3089 }
3090 
3092 WinogradInputTransformOp::getLoopIteratorTypes() {
3093  int64_t outputRank = getOutputOperandRank();
3094  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3095  utils::IteratorType::parallel);
3096  return iteratorTypes;
3097 }
3098 
3099 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3100  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3101  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3102  SmallVector<OpFoldResult> &resultSizes) {
3103  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3104  ShapedType inputType = getInputOperandType();
3105  ArrayRef<int64_t> inputShape = inputType.getShape();
3106  int64_t inputH = inputShape[getInputHDim()];
3107  int64_t inputW = inputShape[getInputWDim()];
3108  int64_t m = getM();
3109  int64_t r = getR();
3110  int64_t alpha = m + r - 1;
3111  int64_t alphaH = inputH != 1 ? alpha : 1;
3112  int64_t alphaW = inputW != 1 ? alpha : 1;
3113  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3114  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3115 
3116  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3117  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3118  offsets[getOutputCDim()]});
3119  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3120  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3121  sizes[getOutputCDim()]});
3122 
3123  return success();
3124 }
3125 
3126 /// Implement tiling for winograd_input_transform
3127 /// The input of winograd_input_transform is (N, H, W, C).
3128 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3129 /// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3130 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3131 /// the values for the sizes of tileH, tileW, N, C for one tile.
3132 FailureOr<TilingResult>
3133 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3134  ArrayRef<OpFoldResult> offsets,
3135  ArrayRef<OpFoldResult> sizes) {
3136  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3137  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3138  ShapedType inputType = getInputOperandType();
3139  ArrayRef<int64_t> inputShape = inputType.getShape();
3140  int64_t inputH = inputShape[getInputHDim()];
3141  int64_t inputW = inputShape[getInputWDim()];
3142  int64_t m = getM();
3143  int64_t r = getR();
3144 
3145  Location loc = getLoc();
3146  MLIRContext *context = builder.getContext();
3147  auto offsetAffineMap =
3148  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3149  Value mappedOffsetH = affine::makeComposedAffineApply(
3150  builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3151  Value mappedOffsetW = affine::makeComposedAffineApply(
3152  builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3153  auto sizeAffineMap = AffineMap::get(
3154  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3155  Value mappedSizeH = affine::makeComposedAffineApply(
3156  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3157  Value mappedSizeW = affine::makeComposedAffineApply(
3158  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3159 
3160  SmallVector<Value> tiledOperands;
3161  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3162 
3163  OpFoldResult offsetH =
3164  inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3165  OpFoldResult offsetW =
3166  inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3167  sliceOffsets.append(
3168  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3169  OpFoldResult sizeH =
3170  inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3171  OpFoldResult sizeW =
3172  inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3173  sliceSizes.append(
3174  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3175  int64_t inputRank = getInputOperandRank();
3176  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3177  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3178  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3179  tiledOperands.emplace_back(inputSlice);
3180 
3181  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3182  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3183  resultSizes)))
3184  return failure();
3185 
3186  int64_t outputRank = getOutputOperandRank();
3187  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3188  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3189  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3190  tiledOperands.emplace_back(outputSlice);
3191 
3192  SmallVector<Type> resultTypes;
3193  resultTypes.push_back(tiledOperands[1].getType());
3194  Operation *tiledOp =
3195  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3196 
3197  return TilingResult{
3198  {tiledOp},
3199  SmallVector<Value>(tiledOp->getResults()),
3200  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3201 }
3202 
3203 //===----------------------------------------------------------------------===//
3204 // WinogradOutputTransformOp
3205 //===----------------------------------------------------------------------===//
3206 
3207 LogicalResult WinogradOutputTransformOp::verify() {
3208  auto valueType = cast<ShapedType>(getValue().getType());
3209  ArrayRef<int64_t> valueShape = valueType.getShape();
3210  int64_t valueH = valueShape[getValueAlphaHDim()];
3211  int64_t valueW = valueShape[getValueAlphaWDim()];
3212  int64_t valueTileH = valueShape[getValueTileHDim()];
3213  int64_t valueTileW = valueShape[getValueTileWDim()];
3214  int m = getM();
3215  int r = getR();
3216  bool leftTransform = valueH != 1;
3217  bool rightTransform = valueW != 1;
3218 
3219  int64_t outputRank = getOutputOperandRank();
3220  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3221  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3222  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3223  } else {
3224  if (valueH != (leftTransform ? m + r - 1 : 1))
3225  return emitOpError("expect input height equals to input tile size");
3226  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3227  }
3228  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3229  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3230  } else {
3231  if (valueW != (rightTransform ? m + r - 1 : 1))
3232  return emitOpError("expect input width equals to input tile size");
3233  expectedOutputShape[getOutputWDim()] =
3234  (rightTransform ? m : 1) * valueTileW;
3235  }
3236  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3237  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3238 
3239  auto outputType = cast<ShapedType>(getOutput().getType());
3240  ArrayRef<int64_t> outputShape = outputType.getShape();
3241  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3242  return emitOpError("the output shape is not expected");
3243  }
3244  return success();
3245 }
3246 
3248 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3249  Location loc = getLoc();
3250  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3251  IntegerAttr oneAttr = builder.getIndexAttr(1);
3252  Value value = getValue();
3253  int64_t valueRank = getValueOperandRank();
3254  SmallVector<Range> loopBounds(valueRank);
3255  for (unsigned dim = 0; dim < valueRank; ++dim) {
3256  loopBounds[dim].offset = zeroAttr;
3257  // alphaH, alphaW, tileH, tileW, N, F
3258  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3259  loopBounds[dim].stride = oneAttr;
3260  }
3261  return loopBounds;
3262 }
3263 
3265 WinogradOutputTransformOp::getLoopIteratorTypes() {
3266  int64_t valueRank = getValueOperandRank();
3267  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3268  utils::IteratorType::parallel);
3269  return iteratorTypes;
3270 }
3271 
/// Compute the slice of the output (N, H, W, F) produced by one tile of the
/// value operand (alphaH, alphaW, tileH, tileW, N, F). The tileH/tileW
/// offsets and sizes are scaled by the output tile size `m` to obtain the
/// corresponding H/W offsets and sizes; the N and F dimensions pass through
/// unchanged. `resultNumber` is unused because the op has a single result.
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t m = getM();

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  // Affine map (d0) -> (d0 * m): maps a tile index/count in the tile domain
  // to an element offset/extent in the output H/W domain.
  auto affineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);

  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, affineMap, offsets[getValueTileWDim()]);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, affineMap, sizes[getValueTileWDim()]);

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  // NOTE(review): valueShape[0]/valueShape[1] are the alphaH/alphaW dims of
  // the transformed value. A dim of 1 presumably means that spatial axis was
  // not Winograd-transformed (1-D variant), so the output slice is pinned to
  // offset 0 / size 1 there — TODO confirm against the transform builders.
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  OpFoldResult offsetH =
      valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
  OpFoldResult offsetW =
      valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
  OpFoldResult sizeH =
      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);

  // Result slice in (N, H, W, F) order; N and F reuse the value-tile
  // offsets/sizes directly.
  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  resultSizes.append(
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  return success();
}
3313 
/// Implement tiling for winograd_output_transform
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
/// F). The output of winograd_output_transform is (N, H, W, F) Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // The alphaH/alphaW dimensions are never tiled: each value tile keeps the
  // full transformed-tile extent, so their slice sizes are the static dim
  // sizes and their offsets are 0.
  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // Slice of the value operand: (0, 0, tileH off, tileW off, N off, F off)
  // with sizes (alphaH, alphaW, tileH size, tileW size, N size, F size).
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Map the value-tile coordinates to the matching output (N, H, W, F) slice.
  // NOTE(review): `1` is passed as resultNumber, but getResultTilePosition
  // ignores that argument (the op has a single result) — confirm whether 0
  // was intended.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the two slices; the result type is the output slice's
  // type (tiledOperands[1]).
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  // Report the tiled op, its results, and the slice ops generated for its
  // operands so the tiling driver can track them.
  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3369 
3370 //===----------------------------------------------------------------------===//
3371 // LinalgDialect
3372 //===----------------------------------------------------------------------===//
3373 
3374 void LinalgDialect::getCanonicalizationPatterns(
3375  RewritePatternSet &results) const {
3376  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3377  InferStaticShapeOfOperands>(getContext());
3378 }
3379 
3381  Attribute value, Type type,
3382  Location loc) {
3383  return arith::ConstantOp::materialize(builder, value, type, loc);
3384 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1814
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:272
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2745
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2816
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:120
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2284
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:260
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:331
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1647
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1695
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2790
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1447
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:156
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1316
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2767
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1207
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1174
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2197
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:298
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:949
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:291
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:52
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1370
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1476
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:324
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:186
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:115
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2589
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2276
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2317
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2256
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2267
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:90
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:347
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:1957
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1960
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:1982
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1985
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.