//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
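/// For illustration, the textual form this handles looks like the following,
/// where each section is optional and the SSA names are placeholders:
///
///   ins(%a, %b : tensor<4xf32>, tensor<4xf32>) outs(%c : tensor<4xf32>)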
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers to build the unary, binary, and type conversion functions
// defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the
// code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

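/// Erase copies whose source and destination are the same value: on buffers
/// such a copy is a no-op, and on tensors the result is just the input. An
/// illustrative example (SSA names are placeholders):
///
///   linalg.copy ins(%a : memref<4xf32>) outs(%a : memref<4xf32>)
///
/// is erased entirely.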
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold a linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
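/// For example (illustrative; SSA names and shapes are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<16xf32>) -> tensor<16xf32>
///   %0 = tensor.expand_shape %fill [[0, 1]] output_shape [4, 4]
///             : tensor<16xf32> into tensor<4x4xf32>
///
/// becomes
///
///   %0 = tensor.expand_shape %init [[0, 1]] output_shape [4, 4]
///             : tensor<16xf32> into tensor<4x4xf32>
///   %1 = linalg.fill ins(%cst : f32)
///             outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>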
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
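/// For example (illustrative; SSA names and shapes are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
///   %pad = tensor.pad %fill low[1, 1] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<4x4xf32> to tensor<6x6xf32>
///
/// becomes a single fill of an appropriately sized tensor.empty:
///
///   %empty = tensor.empty() : tensor<6x6xf32>
///   %pad = linalg.fill ins(%cst : f32)
///             outs(%empty : tensor<6x6xf32>) -> tensor<6x6xf32>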
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
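/// For example (illustrative; SSA names, offsets, and shapes are
/// placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
///   %pad = tensor.pad %src low[1, 1] high[1, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<2x2xf32> to tensor<4x4xf32>
///   %0 = tensor.insert_slice %pad into %fill[2, 2] [4, 4] [1, 1]
///             : tensor<4x4xf32> into tensor<8x8xf32>
///
/// becomes an insertion of the unpadded source at offsets shifted by the low
/// padding, since the padding itself writes the fill value:
///
///   %0 = tensor.insert_slice %src into %fill[3, 3] [2, 2] [1, 1]
///             : tensor<2x2xf32> into tensor<8x8xf32>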
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
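/// For example (illustrative; SSA names are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %e = tensor.extract %fill[%idx] : tensor<4xf32>
///
/// is replaced by %cst directly, since every element of the fill result holds
/// the fill value.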
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
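/// A sketch of the IR involved (illustrative; SSA names, tile sizes, and
/// shapes are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
///   %pack = linalg.pack %fill inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///             into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
///
/// becomes a fill of the packed layout directly:
///
///   %pack = linalg.fill ins(%cst : f32)
///             outs(%dest : tensor<2x2x8x8xf32>) -> tensor<2x2x8x8xf32>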
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
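/// Copying the result of a fill is itself a fill, and a copy that completely
/// overwrites a filled destination makes that fill dead. For example
/// (illustrative; SSA names are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %copy = linalg.copy ins(%fill : tensor<4xf32>)
///             outs(%out : tensor<4xf32>) -> tensor<4xf32>
///
/// becomes
///
///   %copy = linalg.fill ins(%cst : f32)
///             outs(%out : tensor<4xf32>) -> tensor<4xf32>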
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
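/// Transposing a filled tensor is the same as filling the transposed init.
/// For example (illustrative; SSA names and shapes are placeholders):
///
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<2x4xf32>) -> tensor<2x4xf32>
///   %t = linalg.transpose ins(%fill : tensor<2x4xf32>)
///             outs(%out : tensor<4x2xf32>) permutation = [1, 0]
///
/// becomes
///
///   %t = linalg.fill ins(%cst : f32)
///             outs(%out : tensor<4x2xf32>) -> tensor<4x2xf32>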
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
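/// For example (illustrative; SSA names and shapes are placeholders):
///
///   %a = linalg.fill ins(%cst : f32)
///             outs(%initA : tensor<4xf32>) -> tensor<4xf32>
///   %b = linalg.fill ins(%cst : f32)
///             outs(%initB : tensor<4xf32>) -> tensor<4xf32>
///   %c = tensor.concat dim(0) %a, %b
///             : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
///
/// becomes
///
///   %outs = tensor.concat dim(0) %initA, %initB
///             : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
///   %c = linalg.fill ins(%cst : f32)
///             outs(%outs : tensor<8xf32>) -> tensor<8xf32>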
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operations (on tensors) that just copy
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
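/// For example (illustrative; SSA names are placeholders):
///
///   %0 = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> (d0)>],
///          iterator_types = ["parallel"]}
///          ins(%in : tensor<4xf32>) outs(%out : tensor<4xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     linalg.yield %a : f32
///   } -> tensor<4xf32>
///
/// is replaced by %in.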
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, init is placed at the first position of
  // the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
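//
// When a payload op is found, MapOp prints in its short form, e.g.
// (illustrative; SSA names are placeholders):
//
//   %mapped = linalg.map { arith.addf }
//               ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//               outs(%init : tensor<4xf32>)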
1485 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1486  if (body->getOperations().size() != 2)
1487  return nullptr;
1488  Operation &payload = body->getOperations().front();
1489  assert(isa<YieldOp>(body->getOperations().back()));
1490 
1491  if (payload.getNumOperands() == 0 ||
1492  payload.getNumOperands() != body->getNumArguments())
1493  return nullptr;
1494  if (initFirst) {
1495  // check init
1496  if (payload.getOperands().back() != body->getArgument(0))
1497  return nullptr;
1498  // check rest
1499  for (const auto &[operand, bbArg] :
1500  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1501  if (bbArg != operand)
1502  return nullptr;
1503  }
1504  } else {
1505  for (const auto &[operand, bbArg] :
1506  llvm::zip(payload.getOperands(), body->getArguments())) {
1507  if (bbArg != operand)
1508  return nullptr;
1509  }
1510  }
1511  return &payload;
1512 }
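// Illustrative note (an assumption spelled out from the logic above): for a
// reduce with initFirst=true, a body such as
//   (%in: f32, %init: f32) {
//     %0 = arith.addf %init, %in : f32
//     linalg.yield %0 : f32
//   }
// matches, since the single payload op takes the init argument first and the
// remaining block arguments in order.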
1513 
1514 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1515  SmallVector<StringRef> elidedAttrs;
1516  std::string attrToElide;
1517  p << " { " << payloadOp->getName().getStringRef();
1518  for (const auto &attr : payloadOp->getAttrs()) {
1519  auto fastAttr =
1520  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1521  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1522  attrToElide = attr.getName().str();
1523  elidedAttrs.push_back(attrToElide);
1524  break;
1525  }
1526  }
1527  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1528  p << " }";
1529 }
1530 
1531 void MapOp::print(OpAsmPrinter &p) {
1532  Block *mapper = getBody();
1533  Operation *payloadOp = findPayloadOp(mapper);
1534  if (payloadOp) {
1535  printShortForm(p, payloadOp);
1536  }
1537 
1538  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1539  p.printOptionalAttrDict((*this)->getAttrs());
1540 
1541  if (!payloadOp) {
1542  // Print region if the payload op was not detected.
1543  p.increaseIndent();
1544  p.printNewline();
1545  p << "(";
1546  llvm::interleaveComma(mapper->getArguments(), p,
1547  [&](auto arg) { p.printRegionArgument(arg); });
1548  p << ") ";
1549 
1550  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1551  p.decreaseIndent();
1552  }
1553 }
1554 
1555 LogicalResult MapOp::verify() {
1556  auto *bodyBlock = getBody();
1557  auto blockArgs = bodyBlock->getArguments();
1558 
1559  // Check that the number of `inputs` matches the arity of the `mapper` region.
1560  if (getInputs().size() != blockArgs.size())
1561  return emitOpError() << "expects number of operands to match the arity of "
1562  "mapper, but got: "
1563  << getInputs().size() << " and " << blockArgs.size();
1564 
1565  // The parameters of mapper should all match the element type of inputs.
1566  for (const auto &[bbArgType, inputArg] :
1567  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1568  auto inputElemType =
1569  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1570  if (bbArgType != inputElemType) {
1571  return emitOpError() << "expected element type of input " << inputElemType
1572  << " to match bbArg type " << bbArgType;
1573  }
1574  }
1575 
1576  // The shape of each input must match the shape of the output.
1577  auto outputShape = getInit().getType().getShape();
1578  for (Type inputArgType : TypeRange{getInputs()}) {
1579  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1580  if (inputElemShape != outputShape) {
1581  return emitOpError() << "expected shape of input (" << inputElemShape
1582  << ") to match shape of output (" << outputShape
1583  << ")";
1584  }
1585  }
1586 
1587  return success();
1588 }
1589 
1590 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1591  int64_t rank = getInit().getType().getRank();
1592  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1593 }
1594 
1595 ArrayAttr MapOp::getIndexingMaps() {
1596  Builder builder(getContext());
1597  int64_t rank = getInit().getType().getRank();
1598  int64_t numIndexingMaps = getOperands().size();
1599  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1600  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1601 }
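// Illustrative example: a map with two inputs and a rank-2 init yields three
// identity maps, i.e. (d0, d1) -> (d0, d1) repeated once per operand.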
1602 
1603 void MapOp::getEffects(
1604  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1605  &effects) {
1606  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1607 }
1608 
1609 Speculation::Speculatability MapOp::getSpeculatability() {
1610  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1611 }
1612 
1613 //===----------------------------------------------------------------------===//
1614 // ReduceOp
1615 //===----------------------------------------------------------------------===//
1616 
1617 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1618  OpAsmSetValueNameFn setNameFn) {
1619  for (Value v : getRegionInputArgs())
1620  setNameFn(v, "in");
1621  for (Value v : getRegionOutputArgs())
1622  setNameFn(v, "init");
1623 }
1624 
1625 void ReduceOp::getAsmResultNames(
1626  function_ref<void(Value, StringRef)> setNameFn) {
1627  if (!getResults().empty())
1628  setNameFn(getResults().front(), "reduced");
1629 }
1630 
1631 void ReduceOp::build(
1632  OpBuilder &builder, OperationState &result, ValueRange inputs,
1633  ValueRange inits, ArrayRef<int64_t> dimensions,
1634  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1635  ArrayRef<NamedAttribute> attributes) {
1636  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1637  result.addAttributes(attributes);
1638 
1639  // Add output types for `RankedTensorType` output arguments.
1640  for (Value init : inits) {
1641  Type initType = init.getType();
1642  if (llvm::isa<RankedTensorType>(initType))
1643  result.addTypes(initType);
1644  }
1645 
1646  if (bodyBuild)
1647  buildGenericRegion(builder, result.location, *result.regions.front(),
1648  inputs, inits, bodyBuild);
1649 }
1650 
1651 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1652  int64_t inputRank =
1653  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1654  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1655  utils::IteratorType::parallel);
1656  for (int64_t reductionDim : getDimensions())
1657  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1658  return iteratorTypes;
1659 }
1660 
1661 ArrayAttr ReduceOp::getIndexingMaps() {
1662  int64_t inputRank =
1663  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1664  SmallVector<AffineMap> affineMaps(
1665  getNumDpsInputs(),
1666  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1667  AffineMap resultMap =
1668  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1669  .dropResults(getDimensions());
1670  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1671  affineMaps.push_back(resultMap);
1672  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1673 }
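// Illustrative example: for a rank-3 input reduced over dimensions = [1],
// each input map is (d0, d1, d2) -> (d0, d1, d2) and dropResults yields the
// init map (d0, d1, d2) -> (d0, d2).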
1674 
1675 void ReduceOp::getEffects(
1676  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1677  &effects) {
1678  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1679 }
1680 
1681 Speculation::Speculatability ReduceOp::getSpeculatability() {
1682  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1683 }
1684 
1685 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1686  NamedAttrList &attributes,
1687  StringRef attributeName) {
1688  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1689  return failure();
1690 
1691  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1692  return success();
1693 }
1694 
1695 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1696  std::optional<OperationName> payloadOpName;
1697  NamedAttrList payloadOpAttrs;
1698  if (succeeded(parser.parseOptionalLBrace())) {
1699  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1700  if (failed(operationName))
1701  return failure();
1702  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1703  return failure();
1704  payloadOpName = operationName.value();
1705  if (parser.parseRBrace())
1706  return failure();
1707  }
1708 
1709  if (parseDstStyleOp(
1710  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1711  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1712  }))
1713  return failure();
1714 
1715  if (payloadOpName.has_value()) {
1716  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1717  ArrayRef(result.operands), /*initFirst=*/true);
1718  } else {
1719  SmallVector<OpAsmParser::Argument> regionArgs;
1720  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1721  /*allowType=*/true, /*allowAttrs=*/true)) {
1722  return failure();
1723  }
1724 
1725  Region *body = result.addRegion();
1726  if (parser.parseRegion(*body, regionArgs))
1727  return failure();
1728  }
1729 
1730  return success();
1731 }
1732 
1733 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1734  ArrayRef<int64_t> attributeValue) {
1735  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1736 }
1737 
1738 void ReduceOp::print(OpAsmPrinter &p) {
1739  Block *mapper = getBody();
1740  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1741  if (payloadOp) {
1742  printShortForm(p, payloadOp);
1743  }
1744 
1745  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1746  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1747  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1748  if (!payloadOp) {
1749  // Print region if the payload op was not detected.
1750  p.increaseIndent();
1751  p.printNewline();
1752  p << "(";
1753  llvm::interleaveComma(mapper->getArguments(), p,
1754  [&](auto arg) { p.printRegionArgument(arg); });
1755  p << ") ";
1756 
1757  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1758  p.decreaseIndent();
1759  }
1760 }
1761 
1762 LogicalResult ReduceOp::verify() {
1763  ArrayRef<int64_t> dimensionsRef = getDimensions();
1764 
1765  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1766  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1767  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1768  return emitOpError() << "expects all inputs to have the same shapes. "
1769  "Shape at input-index "
1770  << i
1771  << " is not equal to the shape at input-index 0.";
1772  }
1773  }
1774  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1775  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1776  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1777  return emitOpError() << "expects all outputs to have the same shapes. "
1778  "Shape at output-index "
1779  << i
1780  << " is not equal to the shape at output-index 0.";
1781  }
1782  }
1783  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1784  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1785 
1786  DenseSet<int64_t> dimensionsToReduce;
1787  for (int64_t dimension : dimensionsRef) {
1788  if (dimension < 0 || dimension >= inputType.getRank()) {
1789  return emitOpError()
1790  << "dimensions for reduction should be in the range [0, "
1791  << inputType.getRank() - 1 << "].";
1792  }
1793  dimensionsToReduce.insert(dimension);
1794  }
1795 
1796  auto inputDims = inputType.getShape();
1797  auto initDims = initType.getShape();
1798 
1799  // Input dimensions that will be left after the reduction.
1800  SmallVector<int64_t> reducedInputDims;
1801  for (const auto &en : llvm::enumerate(inputDims)) {
1802  if (!dimensionsToReduce.count(en.index()))
1803  reducedInputDims.push_back(en.value());
1804  }
1805 
1806  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1807  return emitOpError() << "number of dimensions after reduction "
1808  << reducedInputDims.size()
1809  << " doesn't match the init rank "
1810  << initType.getRank();
1811  }
1812 
1813  if (reducedInputDims != initDims)
1814  return emitOpError() << "init dimensions [" << initDims
1815  << "] doesn't match input dimensions after reduction ["
1816  << reducedInputDims << "]";
1817 
1818  Block *block = getBody();
1819  if (block->getNumArguments() != this->getNumOperands())
1820  return emitOpError()
1821  << "mismatching number of operands and block arguments";
1822 
1823  // Check that the first block arguments match the element type of the inputs.
1824  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1825  Type inputElementType =
1826  llvm::cast<ShapedType>(input.getType()).getElementType();
1827  if (inputElementType != bbArg.getType())
1828  return emitOpError()
1829  << "input element type " << inputElementType
1830  << " does not match corresponding block argument type "
1831  << bbArg.getType();
1832  }
1833 
1834  // Check that the last block arguments match the element type of the outputs.
1835  for (auto [output, bbArg] : llvm::zip(
1836  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1837  auto outputElementType =
1838  llvm::cast<ShapedType>(output.getType()).getElementType();
1839  if (outputElementType != bbArg.getType())
1840  return emitOpError()
1841  << "output element type " << outputElementType
1842  << " does not match corresponding block argument type "
1843  << bbArg.getType();
1844  }
1845  return success();
1846 }
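// Illustrative example of an op accepted by this verifier:
//   linalg.reduce { arith.addf }
//     ins(%in : tensor<4x8xf32>) outs(%init : tensor<4xf32>)
//     dimensions = [1]
// Dropping dimension 1 from the 4x8 input leaves [4], matching the init.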
1847 
1848 //===----------------------------------------------------------------------===//
1849 // TransposeOp
1850 //===----------------------------------------------------------------------===//
1851 
1852 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1853  Region &region, ValueRange inputs,
1854  ValueRange outputs) {
1855  buildGenericRegion(builder, loc, region, inputs, outputs,
1856  [](OpBuilder &b, Location loc, ValueRange args) {
1857  if (!args.empty())
1858  b.create<linalg::YieldOp>(loc, args[0]);
1859  });
1860 }
1861 
1862 void TransposeOp::build(::mlir::OpBuilder &builder,
1863  ::mlir::OperationState &result, Value input, Value init,
1864  DenseI64ArrayAttr permutation,
1865  ArrayRef<NamedAttribute> attributes) {
1866  result.addOperands(input);
1867  result.addOperands(init);
1868  result.addAttribute(getPermutationAttrName(result.name), permutation);
1869  result.addAttributes(attributes);
1870 
1871  // Add output types for `RankedTensorType` output arguments.
1872  Type initType = init.getType();
1873  if (llvm::isa<RankedTensorType>(initType))
1874  result.addTypes(initType);
1875 
1876  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1877  init);
1878 }
1879 
1880 void TransposeOp::build(::mlir::OpBuilder &builder,
1881  ::mlir::OperationState &result, Value input, Value init,
1882  ArrayRef<int64_t> permutation,
1883  ArrayRef<NamedAttribute> attributes) {
1884  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1885  attributes);
1886 }
1887 
1888 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1889  if (failed(parseDstStyleOp(
1890  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1891  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1892  })))
1893  return failure();
1894 
1895  OpBuilder builder(parser.getContext());
1896  buildIdentityRegion(builder, result.location, *result.addRegion(),
1897  /*inputs=*/result.operands,
1898  /*outputs=*/{});
1899  return success();
1900 }
1901 
1902 void TransposeOp::getAsmResultNames(
1903  function_ref<void(Value, StringRef)> setNameFn) {
1904  if (!getResults().empty())
1905  setNameFn(getResults().front(), "transposed");
1906 }
1907 
1908 void TransposeOp::print(OpAsmPrinter &p) {
1909  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1910  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1911  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1912 }
1913 
1914 LogicalResult TransposeOp::verify() {
1915  ArrayRef<int64_t> permutationRef = getPermutation();
1916 
1917  if (!isPermutationVector(permutationRef))
1918  return emitOpError("permutation is not valid");
1919 
1920  auto inputType = getInput().getType();
1921  auto initType = getInit().getType();
1922 
1923  int64_t rank = inputType.getRank();
1924 
1925  if (rank != initType.getRank())
1926  return emitOpError() << "input rank " << rank
1927  << " does not match init rank " << initType.getRank();
1928 
1929  if (rank != static_cast<int64_t>(permutationRef.size()))
1930  return emitOpError() << "size of permutation " << permutationRef.size()
1931  << " does not match the argument rank " << rank;
1932 
1933  auto inputDims = inputType.getShape();
1934  auto initDims = initType.getShape();
1935 
1936  for (int64_t i = 0; i < rank; ++i) {
1937  int64_t inputDim = inputDims[permutationRef[i]];
1938  int64_t initDim = initDims[i];
1939 
1940  if (inputDim != initDim) {
1941  return emitOpError() << "dim(result, " << i << ") = " << initDim
1942  << " doesn't match dim(input, permutation[" << i
1943  << "]) = " << inputDim;
1944  }
1945  }
1946 
1947  return success();
1948 }
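// Illustrative example: with permutation = [1, 0, 2], an input of type
// tensor<2x4x8xf32> needs an init of type tensor<4x2x8xf32>, because
// dim(init, i) must equal dim(input, permutation[i]) for every i.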
1949 
1950 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1951  int64_t rank = getInit().getType().getRank();
1952  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1953 }
1954 
1955 ArrayAttr TransposeOp::getIndexingMaps() {
1956  Builder builder(getContext());
1957  int64_t rank = getInit().getType().getRank();
1958  return builder.getAffineMapArrayAttr(
1959  {inversePermutation(AffineMap::getPermutationMap(
1960  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1961  builder.getMultiDimIdentityMap(rank)});
1962 }
1963 
1964 void TransposeOp::getEffects(
1965  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1966  &effects) {
1967  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1968 }
1969 
1970 Speculation::Speculatability TransposeOp::getSpeculatability() {
1971  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1972 }
1973 
1974 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1975  SmallVectorImpl<OpFoldResult> &result) {
1976  // Only the tensor type is supported.
1977  if (!isa<TensorType>(getInput().getType()))
1978  return failure();
1979 
1980  // Single dimension transpose.
1981  if (getPermutation().size() == 0) {
1982  result.push_back(getInput());
1983  return success();
1984  }
1985  // Identity permutation.
1986  if (isIdentityPermutation(getPermutation())) {
1987  result.push_back(getInput());
1988  return success();
1989  }
1990 
1991  return failure();
1992 }
1993 
1994 /// Fold transpose with transpose.
1995 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1996  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1997 
1998  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1999  PatternRewriter &rewriter) const override {
2000  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2001  if (!defTransposeOp)
2002  return failure();
2003  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2004  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2005  SmallVector<int64_t> foldedPerms;
2006  foldedPerms.reserve(perms.size());
2007  for (int64_t perm : perms)
2008  foldedPerms.push_back(defPerms[perm]);
2009 
2010  rewriter.replaceOpWithNewOp<TransposeOp>(
2011  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2012  foldedPerms);
2013  return success();
2014  }
2015 };
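// Illustrative example: a transpose with perms = [2, 0, 1] of a transpose
// with defPerms = [1, 2, 0] composes to foldedPerms[i] = defPerms[perms[i]]
// = [0, 1, 2], an identity that TransposeOp::fold then removes (for tensors).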
2016 
2017 /// This pattern canonicalizes transpose by swapping the order of
2018 /// broadcast and transpose:
2019 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2020 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2021  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2022 
2023  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2024  PatternRewriter &rewriter) const override {
2025  Value input = transposeOp.getInput();
2026  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2027  if (!input.hasOneUse() || !broadcastOp)
2028  return failure();
2029 
2030  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2031  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2032 
2033  // Get new perms and new dimensions.
2034  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2035  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2036  SmallVector<int64_t> resultDimensions;
2037  unsigned dimensionSize = dimensions.size();
2038  for (unsigned i = 0; i < dimensionSize; ++i)
2039  resultDimensions.push_back(invertPerm[dimensions[i]]);
2040 
2041  // Create transpose result.
2042  Value broadcastInput = broadcastOp.getInput();
2043  Location loc = transposeOp.getLoc();
2044  MLIRContext *ctx = transposeOp.getContext();
2045  SmallVector<OpFoldResult> dims;
2046  auto broadcastInputTy =
2047  mlir::cast<RankedTensorType>(broadcastInput.getType());
2048  unsigned inputRank = broadcastInputTy.getRank();
2049  for (unsigned i = 0; i < inputRank; ++i) {
2050  if (broadcastInputTy.isDynamicDim(i)) {
2051  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2052  ->getResult(0));
2053  } else {
2054  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2055  broadcastInputTy.getDimSize(i)));
2056  }
2057  }
2058  SmallVector<OpFoldResult> transposeResultShapes =
2059  applyPermutation(dims, resultPerms);
2060  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2061  transposeOp.getLoc(), transposeResultShapes,
2062  broadcastInputTy.getElementType());
2063 
2064  // Create broadcast(transpose(input)).
2065  Value transposeResult =
2066  rewriter
2067  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2068  resultPerms)
2069  ->getResult(0);
2070  rewriter.replaceOpWithNewOp<BroadcastOp>(
2071  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2072  return success();
2073  }
2074 };
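// Illustrative example (worked through under the logic above): broadcasting
// tensor<4xf32> with dimensions = [0] to tensor<2x4xf32> and then transposing
// with perms = [1, 0] becomes a transpose of the 1-D input with
// resultPerms = dropDims([1, 0], [0]) = [0], followed by a broadcast with
// resultDimensions = [invertPerm[0]] = [1] into the same tensor<4x2xf32>.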
2075 
2076 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2077  MLIRContext *context) {
2078  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2079 }
2080 
2081 //===----------------------------------------------------------------------===//
2082 // BroadcastOp
2083 //===----------------------------------------------------------------------===//
2084 
2085 void BroadcastOp::build(::mlir::OpBuilder &builder,
2086  ::mlir::OperationState &result, Value input, Value init,
2087  DenseI64ArrayAttr dimensions,
2088  ArrayRef<NamedAttribute> attributes) {
2089  result.addOperands(input);
2090  result.addOperands(init);
2091  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2092  result.addAttributes(attributes);
2093 
2094  // Add output types for `RankedTensorType` output arguments.
2095  Type initType = init.getType();
2096  if (llvm::isa<RankedTensorType>(initType))
2097  result.addTypes(initType);
2098 
2099  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2100  init);
2101 }
2102 
2103 void BroadcastOp::build(::mlir::OpBuilder &builder,
2104  ::mlir::OperationState &result, Value input, Value init,
2105  ArrayRef<int64_t> dimensions,
2106  ArrayRef<NamedAttribute> attributes) {
2107  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2108  attributes);
2109 }
2110 
2111 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2112  if (failed(parseDstStyleOp(
2113  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2114  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2115  })))
2116  return failure();
2117 
2118  OpBuilder builder(parser.getContext());
2119  buildIdentityRegion(builder, result.location, *result.addRegion(),
2120  /*inputs=*/result.operands,
2121  /*outputs=*/{});
2122  return success();
2123 }
2124 
2125 void BroadcastOp::getAsmResultNames(
2126  function_ref<void(Value, StringRef)> setNameFn) {
2127  if (!getResults().empty())
2128  setNameFn(getResults().front(), "broadcasted");
2129 }
2130 
2131 void BroadcastOp::print(OpAsmPrinter &p) {
2132  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2133  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2134  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2135 }
2136 
2137 LogicalResult BroadcastOp::verify() {
2138  ArrayRef<int64_t> dimensionsRef = getDimensions();
2139 
2140  auto inputType = getInput().getType();
2141  auto initType = getInit().getType();
2142 
2143  int64_t inputRank = inputType.getRank();
2144  int64_t initRank = initType.getRank();
2145 
2146  auto inputShape = inputType.getShape();
2147  auto initShape = initType.getShape();
2148 
2149  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2150  return emitOpError() << "input rank plus added dimensions does not "
2151  "match init rank. input rank: "
2152  << inputRank
2153  << ", dimensions size: " << dimensionsRef.size()
2154  << ", init rank: " << initRank;
2155 
2156  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2157  if (dim < 0 || dim >= initRank)
2158  return emitOpError() << "dimension " << idx
2159  << " is out of range. expected range: [0, "
2160  << initRank - 1 << "], got: " << dim;
2161  }
2162 
2163  // Mapping from input dims to init dims.
2164  SmallVector<int64_t> dimMap;
2165  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2166  if (!llvm::is_contained(dimensionsRef, dim))
2167  dimMap.push_back(dim);
2168  }
2169 
2170  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2171  // This dimension is mapped from the input. Init and input dims should
2172  // match.
2173  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2174  return emitOpError() << "input dim " << inputDimIdx
2175  << " should match init dim " << initDimIdx
2176  << ". input: " << inputShape[inputDimIdx]
2177  << ", init: " << initShape[initDimIdx];
2178  }
2179 
2180  return success();
2181 }
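// Illustrative example: broadcasting tensor<4x5xf32> with dimensions = [1]
// into tensor<4x?x5xf32> passes this verifier: init dim 1 is the added
// dimension, and the remaining init dims 0 and 2 map back to input dims 0
// and 1 with matching sizes.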
2182 
2183 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2184  int64_t rank = getInit().getType().getRank();
2185  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2186 }
2187 
2188 ArrayAttr BroadcastOp::getIndexingMaps() {
2189  Builder builder(getContext());
2190  int64_t rank = getInit().getType().getRank();
2191  return builder.getAffineMapArrayAttr(
2192  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2193  builder.getMultiDimIdentityMap(rank)});
2194 }
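// Illustrative example: for dimensions = [1] and a rank-3 init, the maps are
// [(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d0, d1, d2)].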
2195 
2196 void BroadcastOp::getEffects(
2197  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2198  &effects) {
2199  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2200 }
2201 
2202 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2203  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2204 }
2205 
2206 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2207  MLIRContext *context) {
2208  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2209 }
2210 
2211 //===----------------------------------------------------------------------===//
2212 // YieldOp
2213 //===----------------------------------------------------------------------===//
2214 
2215 void linalg::YieldOp::print(OpAsmPrinter &p) {
2216  if (getNumOperands() > 0)
2217  p << ' ' << getOperands();
2218  p.printOptionalAttrDict((*this)->getAttrs());
2219  if (getNumOperands() > 0)
2220  p << " : " << getOperandTypes();
2221 }
2222 
2223 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2224  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2225  SmallVector<Type, 2> types;
2226  SMLoc loc = parser.getCurrentLocation();
2227  return failure(parser.parseOperandList(opInfo) ||
2228  parser.parseOptionalAttrDict(result.attributes) ||
2229  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2230  parser.resolveOperands(opInfo, types, loc, result.operands));
2231 }
2232 
2233 // Check that the operand count and types match the element types of the
2234 // LinalgOp interface's shaped operands.
2235 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2236  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2237  return op.emitOpError("expected number of yield values (")
2238  << op.getNumOperands()
2239  << ") to match the number of inits / outs operands of the enclosing "
2240  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2241 
2242  for (OpOperand &opOperand : op->getOpOperands()) {
2243  OpOperand *outputOperand =
2244  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2245  Type elementType = outputOperand->get().getType();
2246  if (isa<MemRefType, RankedTensorType>(elementType))
2247  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2248  if (opOperand.get().getType() != elementType)
2249  return op.emitOpError("type of yield operand ")
2250  << (opOperand.getOperandNumber() + 1) << " ("
2251  << opOperand.get().getType() << ") doesn't match "
2252  << "the element type of the enclosing linalg.generic op ("
2253  << elementType << ")";
2254  }
2255  return success();
2256 }
2257 
2258 LogicalResult linalg::YieldOp::verify() {
2259  auto *parentOp = (*this)->getParentOp();
2260  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2261  return emitOpError("expected single non-empty parent region");
2262 
2263  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2264  return verifyYield(*this, linalgOp);
2265 
2266  return emitOpError("expected parent op with LinalgOp interface");
2267 }
2268 
2269 //===----------------------------------------------------------------------===//
2270 // IndexOp
2271 //===----------------------------------------------------------------------===//
2272 
2273 LogicalResult IndexOp::verify() {
2274  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2275  if (!linalgOp)
2276  return emitOpError("expected parent op with LinalgOp interface");
2277  if (linalgOp.getNumLoops() <= getDim())
2278  return emitOpError("expected dim (")
2279  << getDim() << ") to be lower than the number of loops ("
2280  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2281  return success();
2282 }
2283 
2284 /////// Operations corresponding to library calls defined with Tablegen ////////
2285 
2286 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2287 
2288 #define GET_OP_CLASSES
2289 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2290 
2291 #define GET_OP_CLASSES
2292 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2293 #define GET_OP_CLASSES
2294 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2295 
2296 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2297  unsigned rank,
2298  MLIRContext *context) {
2299  if (maybeMap)
2300  return *maybeMap;
2301  if (rank == 0)
2302  return AffineMap::get(context);
2303  return AffineMap::getMultiDimIdentityMap(rank, context);
2304 }
2305 
2306 SmallVector<AffineExpr, 4>
2307 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2308  MLIRContext *context) {
2309  SmallVector<AffineExpr, 4> res;
2310  res.reserve(num);
2311  for (unsigned i = 0; i < num; ++i)
2312  res.push_back(getAffineDimExpr(startIdx++, context));
2313  return res;
2314 }
2315 
2316 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2317  ArrayRef<AffineExpr> b) {
2318  auto rangeA = llvm::make_range(a.begin(), a.end());
2319  auto rangeB = llvm::make_range(b.begin(), b.end());
2320  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2321  return llvm::to_vector<4>(concatRanges);
2322 }
2323 
2324 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2325  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2326  ss << "view";
2327  for (auto size : memref.getShape())
2328  if (size < 0)
2329  ss << "sx";
2330  else
2331  ss << size << "x";
2332  if (failed(appendMangledType(ss, memref.getElementType())))
2333  return failure();
2334  if (auto as = memref.getMemorySpace()) {
2335  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2336  ss << "as" << attr.getInt();
2337  else
2338  return failure();
2339  }
2340  return success();
2341  }
2342  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2343  ss << "vector";
2344  llvm::interleave(
2345  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2346  if (failed(appendMangledType(ss, vec.getElementType())))
2347  return failure();
2348  return success();
2349  }
2350  if (t.isSignlessIntOrIndexOrFloat()) {
2351  ss << t;
2352  return success();
2353  }
2354  return failure();
2355 }
2356 
2357 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2358  assert(isa<LinalgOp>(op));
2359  std::string name(op->getName().getStringRef().str());
2360  std::string fun = "";
2361  for (NamedAttribute kv : op->getAttrs()) {
2362  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2363  fun = stringifyEnum(ufa.getValue()).str() + "_";
2364  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2365  fun = stringifyEnum(bfa.getValue()).str() + "_";
2366  }
2367  }
2368  name.reserve(128);
2369  std::replace(name.begin(), name.end(), '.', '_');
2370  llvm::raw_string_ostream ss(name);
2371  ss << "_" << fun;
2372  for (Type t : op->getOperandTypes()) {
2373  if (failed(appendMangledType(ss, t)))
2374  return std::string();
2375  ss << "_";
2376  }
2377  name.pop_back();
2378  return name;
2379 }
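// Illustrative example: a linalg.copy on two memref<?xf32> operands mangles
// to "linalg_copy_viewsxf32_viewsxf32": the '.' becomes '_', and each operand
// appends "view", then "sx" per dynamic size, then the element type.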
2380 
2381 //===----------------------------------------------------------------------===//
2382 // Canonicalizers and Folders.
2383 //===----------------------------------------------------------------------===//
2384 
2385 namespace {
2386 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2387  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2388 
2389  LogicalResult matchAndRewrite(LinalgOp op,
2390  PatternRewriter &rewriter) const override {
2391  for (OpOperand &opOperand : op->getOpOperands()) {
2392  // Linalg "inputs" may be either tensor or memref type.
2393  // tensor<0xelt_type> is a convention that may not always mean
2394  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2395  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2396  if (!mt)
2397  continue;
2398  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2399  rewriter.eraseOp(op);
2400  return success();
2401  }
2402  }
2403  return failure();
2404  }
2405 };
2406 
2407 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2408 /// result that is more static than the linalg op.
2409 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2410  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2411 
2412  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2413  PatternRewriter &rewriter) const override {
2414  if (!tensor::canFoldIntoProducerOp(castOp))
2415  return failure();
2416 
2417  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2418  if (!linalgOp)
2419  return failure();
2420 
2421  // The cast can be in a conditionally reachable region, in which case
2422  // folding would generate invalid code. Only conservatively fold ops in the
2423  // same block for now.
2424  if (castOp->getBlock() != linalgOp->getBlock())
2425  return failure();
2426 
2427  OpBuilder::InsertionGuard guard(rewriter);
2428  rewriter.setInsertionPoint(linalgOp);
2429 
2430  Location loc = linalgOp.getLoc();
2431  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2432  unsigned resultNumber = resultValue.getResultNumber();
2433  auto resultType =
2434  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2435  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2436  // going from a more dynamic shape to a less dynamic shape. If the producer
2437  // for this cast, i.e. producer of the out operand, is also an operation
2438  // that folds with tensor.cast consumer (like this pattern), the cast will
2439  // continue to propagate as far up the stack as it can go.
2440  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2441  Value newOperand =
2442  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2443  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2444  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2445  linalgOp.getDpsInits().end());
2446  outputOperands[resultNumber] = newOperand;
2447  newOperands.append(outputOperands.begin(), outputOperands.end());
2448 
2449  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2450  linalgOp->result_type_end());
2451  resultTypes[resultNumber] = resultType;
2452  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2453 
2454  // Create a tensor.cast operation back to the original type.
2455  Value castBack = rewriter.create<tensor::CastOp>(
2456  loc, resultValue.getType(), newOp->getResult(resultNumber));
2457 
2458  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2459  results[resultNumber] = castBack;
2460  rewriter.replaceOp(linalgOp, results);
2461  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2462  return success();
2463  }
2464 };
2465 
2466 /// For each operand in `operands`, this function maps the static sizes of its
2467 /// dimensions to the corresponding affine dim expressions.
2468 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2469  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2470  for (OpOperand &opOperand : operands) {
2471  if (linalgOp.isScalar(&opOperand))
2472  continue;
2473  Value src = opOperand.get();
2474  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2475  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2476 
2477  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2478  // `tensor.cast` operation and source of the cast operation has a static
2479  // shape, then assign it to the `sourceShape`.
2480  auto *parentOp = src.getDefiningOp();
2481  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2482  if (parentOp) {
2483  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2484  Value castSource = castOp.getSource();
2485  auto castSourceType =
2486  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2487  if (castSourceType && castSourceType.hasStaticShape())
2488  sourceShape = castSourceType.getShape();
2489  }
2490  }
2491 
2492  // If a dimension of the source shape is static, map its affine dim
2493  // expression to the known static size.
2494  for (unsigned i = 0; i < sourceShape.size(); i++) {
2495  if (sourceType.isDynamicDim(i))
2496  continue;
2497  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2498  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2499  }
2500  }
2501 }
2502 
2503 /// Creates a new operand for `opOperand` of `linalgOp` with the static sizes
2504 /// mapped in `affineExprToSize`. New operands are appended to `newOperands` and
2505 /// their result types are stored in `resultTypes`. If `opOperand` requires no
2506 /// change, then `changeNeeded` stays false and the same operand is added to the
2507 /// `newOperands` list.
2508 static void createNewOperandWithStaticSizes(
2509  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2510  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2511  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2512  bool &changeNeeded) {
2513  Value src = opOperand->get();
2514  newOperands.push_back(src);
2515  if (linalgOp.isScalar(opOperand))
2516  return;
2517  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2518  Type resultType = sourceType;
2519  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2520  resultTypes.push_back(resultType);
2521  return;
2522  }
2523  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2524  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2525  SmallVector<int64_t> newShape;
2526  // If the operand is updated with a new shape, `newOperandNeeded` will be
2527  // set to true.
2528  bool newOperandNeeded = false;
2529  for (unsigned i = 0; i < sourceShape.size(); i++) {
2530  int64_t dimShape = sourceShape[i];
2531  AffineExpr dimExpr = sourceMap.getResult(i);
2532  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2533  newShape.push_back(dimShape);
2534  continue;
2535  }
2536  // The dimension is dynamic and the corresponding affine dim
2537  // expression is present in the map, so assign the mapped size to
2538  // this dimension.
2539  newShape.push_back(affineExprToSize[dimExpr]);
2540  newOperandNeeded = true;
2541  }
2542  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2543  if (newOperandNeeded) {
2544  changeNeeded = true;
2545  // Get the new operand value given its size and element type by
2546  // casting it.
2547  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2548  unsigned index = opOperand->getOperandNumber();
2549  newOperands[index] = newOperand;
2550  }
2551  if (linalgOp.isDpsInit(opOperand))
2552  resultTypes.push_back(resultType);
2553 }
2554 
2555 /// Static shapes for the operands can be inferred if any one of the operands
2556 /// has a static shape. This can be done by referring to the affine dim
2557 /// expressions for the operand.
2558 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2559  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2560 
2561  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2562  PatternRewriter &rewriter) const override {
2563  if (!linalgOp.hasPureTensorSemantics())
2564  return failure();
2565 
2566  // Maps must be projected permutations.
2567  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2568  return !map.isProjectedPermutation();
2569  }))
2570  return failure();
2571 
2572  // Maps affine dim expressions to the static size of that dimension.
2573  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2574  Location loc = linalgOp.getLoc();
2575 
2576  // For each affine dim expression, check if its size is known. If
2577  // known, add it to the map.
2578  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2579 
2580  SmallVector<Value> newOperands;
2581  SmallVector<Type> resultTypes;
2582 
2583  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2584  // change in their types.
2585  bool changeNeeded = false;
2586  newOperands.reserve(linalgOp->getNumOperands());
2587  resultTypes.reserve(linalgOp.getNumDpsInits());
2588 
2589  // Iterate over all the operands and update the static sizes.
2590  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2591  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2592  affineExprToSize, linalgOp, newOperands,
2593  resultTypes, changeNeeded);
2594  }
2595 
2596  // If the generic op has all the required static information, no
2597  // canonicalization needed.
2598  if (!changeNeeded)
2599  return failure();
2600 
2601  // Clone op.
2602  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2603  SmallVector<Value> replacements;
2604  replacements.reserve(newOp->getNumResults());
2605  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2606  Value newResult = std::get<1>(it);
2607  Value oldResult = std::get<0>(it);
2608  Type newType = newResult.getType();
2609  Type oldType = oldResult.getType();
2610  replacements.push_back(
2611  (newType != oldType)
2612  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2613  : newResult);
2614  }
2615  rewriter.replaceOp(linalgOp, replacements);
2616  return success();
2617  }
2618 };
2619 
2620 } // namespace
2621 
2622 // All named ops canonicalizers and folders are auto-generated in the
2623 // .cpp.inc.
2624 
2625 //===----------------------------------------------------------------------===//
2626 // SoftmaxOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 LogicalResult SoftmaxOp::verify() {
2630  ShapedType inputType = getInputOperandType();
2631  ShapedType outputType = getOutputOperandType();
2632 
2633  ArrayRef<int64_t> inputShape = inputType.getShape();
2634  ArrayRef<int64_t> outputShape = outputType.getShape();
2635  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2636  return emitOpError("incompatible output shape");
2637 
2638  int64_t inputRank = getInputOperandRank();
2639  int64_t dimension = getDimension();
2640  if ((dimension < 0) || (dimension >= inputRank))
2641  return emitOpError("incorrect dimension specified");
2642 
2643  return success();
2644 }
2645 
2646 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2647  int64_t operandRank = getInputOperandRank();
2648  SmallVector<Range> loopBounds(operandRank);
2649  Location loc = getLoc();
2650  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2651  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2652  Value source = getInput();
2653  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2654  loopBounds[dim].offset = zero;
2655  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2656  loopBounds[dim].stride = one;
2657  }
2658  return loopBounds;
2659 }
2660 
2661 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2662  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2663  utils::IteratorType::parallel);
2664  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2665  return iteratorTypes;
2666 }
2667 
2668 FailureOr<TilingResult>
2669 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2670  ArrayRef<OpFoldResult> offsets,
2671  ArrayRef<OpFoldResult> sizes) {
2672  int64_t rank = getInputOperandRank();
2673  auto oneAttr = builder.getI64IntegerAttr(1);
2674  SmallVector<OpFoldResult> strides(rank, oneAttr);
2675  SmallVector<Value> tiledOperands;
2676  Operation *inputSlice =
2677  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2678  if (!inputSlice) {
2679  return emitOpError("failed to compute input slice");
2680  }
2681  tiledOperands.emplace_back(inputSlice->getResult(0));
2682  Operation *outputSlice =
2683  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2684  if (!outputSlice) {
2685  return emitOpError("failed to compute output slice");
2686  }
2687  tiledOperands.emplace_back(outputSlice->getResult(0));
2688 
2689  SmallVector<Type, 4> resultTypes;
2690  if (hasPureTensorSemantics())
2691  resultTypes.push_back(tiledOperands[1].getType());
2692  Operation *tiledOp =
2693  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2694 
2695  return TilingResult{
2696  {tiledOp},
2697  SmallVector<Value>(tiledOp->getResults()),
2698  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2699 }
2700 
2701 LogicalResult SoftmaxOp::getResultTilePosition(
2702  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2703  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2704  SmallVector<OpFoldResult> &resultSizes) {
2705  if (resultNumber == 0) {
2706  resultOffsets.assign(offsets.begin(), offsets.end());
2707  resultSizes.assign(sizes.begin(), sizes.end());
2708  return success();
2709  }
2710  return failure();
2711 }
2712 
2713 // cast(dynamic) -> static.
2714 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2715  return memref::foldMemRefCast(*this);
2716 }
2717 
2718 LogicalResult
2719 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2720  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2721  SmallVector<OpFoldResult> shapes;
2722  Location loc = getOperation()->getLoc();
2723  IRRewriter rewriter(b);
2724  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2725  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2726  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2727  if (!outputShapedType.isDynamicDim(dim)) {
2728  // Static dim: Return IntegerAttr.
2729  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2730  } else {
2731  // Dynamic dim: Return Value.
2732  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2733  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2734  }
2735  }
2736  reifiedReturnShapes.emplace_back(std::move(shapes));
2737  return success();
2738 }
2739 
2740 void SoftmaxOp::getEffects(
2741  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2742  &effects) {
2743  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2744  if (!llvm::isa<MemRefType>(operand.getType()))
2745  continue;
2746  effects.emplace_back(MemoryEffects::Read::get(),
2747  &getOperation()->getOpOperand(index), /*stage=*/0,
2748  /*effectOnFullRegion=*/true,
2749  SideEffects::DefaultResource::get());
2750  }
2751 
2752  for (OpOperand &operand : getDpsInitsMutable()) {
2753  if (!llvm::isa<MemRefType>(operand.get().getType()))
2754  continue;
2755  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2756  /*effectOnFullRegion=*/true,
2757  SideEffects::DefaultResource::get());
2758  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2759  /*effectOnFullRegion=*/true,
2760  SideEffects::DefaultResource::get());
2761  }
2762 }
2763 
2764 // Helper functions for softmax decomposition.
2765 // @{
2766 
2767 // Helper function to produce the iterator types (reduction or parallel) and
2768 // affine maps for the iterators used in the decomposition of softmax.
2769 // This method creates:
2770 // If allParallel == true:
2771 // - iterator type: {parallel, ..., parallel}
2772 // - affine maps:
2773 // -- identity with inputRank dimensions.
2774 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2775 // where N == inputRank.
2776 //
2777 // If allParallel == false:
2778 // - iterator type at dim(i) == parallel for i != \p dim and
2779 // dim(dim) == reduction.
2780 // - affine map:
2781 // -- identity with inputRank dimensions.
2782 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2783 // where N == inputRank.
2784 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2785 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2786  int64_t dim, bool allParallel = false) {
2787  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2788  utils::IteratorType::parallel);
2789  if (!allParallel)
2790  iteratorTypes[dim] = utils::IteratorType::reduction;
2791  MLIRContext *ctxt = builder.getContext();
2792  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2793  SmallVector<AffineExpr, 2> affineExprs;
2794  for (int i = 0; i < inputRank; i++) {
2795  if (i != dim)
2796  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2797  }
2798  auto reductionMap =
2799  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2800  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2801  return std::make_tuple(iteratorTypes, indexingMaps);
2802 }
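// Illustrative example: for inputRank = 3 and dim = 1 this returns iterator
// types {parallel, reduction, parallel} (all parallel when allParallel is
// set) and maps [(d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d0, d2)].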
2803 
2804 // Helper function to produce a linalg.generic that computes a reduction on
2805 // dimension \p dim with the operation type \p T.
2806 template <typename T>
2807 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2808  int64_t dim) {
2809  auto inputType = cast<ShapedType>(input.getType());
2810  ArrayRef<int64_t> inputShape = inputType.getShape();
2811  int64_t inputRank = inputShape.size();
2812  auto [iteratorTypes, indexingMaps] =
2813  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2814  assert(indexingMaps.size() == 2 &&
2815  "We should have two maps: 1 for the input, 1 for the output");
2816  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2817 
2818  auto genericOp = builder.create<linalg::GenericOp>(
2819  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2820  [&](OpBuilder &b, Location loc, ValueRange args) {
2821  Value result = b.create<T>(loc, args[0], args[1]);
2822  b.create<linalg::YieldOp>(loc, result);
2823  });
2824  return genericOp.getResult(0);
2825 }
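// Illustrative use (see decomposeOperation below): reduce<arith::MaxNumFOp>
// computes a running max along `dim`, and reduce<arith::AddFOp> the sum.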
2826 
2827 /// Produce a linalg generic that computes the second step of the softmax
2828 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2829 /// on dimension \p dim.
2830 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2831  Value max, Value output, int64_t dim) {
2832  auto inputType = cast<ShapedType>(input.getType());
2833  ArrayRef<int64_t> inputShape = inputType.getShape();
2834  int64_t inputRank = inputShape.size();
2835  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2836  builder, inputRank, dim, /*allParallel=*/true);
2837  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2838  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2839  // Add the affine map for the output argument.
2840  indexingMaps.push_back(indexingMaps[0]);
2841  auto genericOp = builder.create<linalg::GenericOp>(
2842  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2843  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2844  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2845  Value result = b.create<math::ExpOp>(loc, diff);
2846  b.create<linalg::YieldOp>(loc, result);
2847  });
2848  return genericOp.getResult(0);
2849 }
2850 
2851 /// Produce a linalg generic that computes the final step of the softmax
2852 /// decomposition.
2853 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2854 /// yield n / d
2855 /// }
2856 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2857  Value denominator, Value output, int64_t dim) {
2858  auto inputType = cast<ShapedType>(numerator.getType());
2859  ArrayRef<int64_t> inputShape = inputType.getShape();
2860  int64_t inputRank = inputShape.size();
2861  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2862  builder, inputRank, dim, /*allParallel=*/true);
2863  assert(indexingMaps.size() == 2 &&
2864  "We should have one map for each input (2)");
2865  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2866  // Add the affine map for the output tensor.
2867  indexingMaps.push_back(indexingMaps[0]);
2868  auto genericOp = builder.create<linalg::GenericOp>(
2869  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2870  indexingMaps, iteratorTypes,
2871  [&](OpBuilder &b, Location loc, ValueRange args) {
2872  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2873  b.create<linalg::YieldOp>(loc, result);
2874  });
2875  return genericOp.getResult(0);
2876 }
2877 // @} End helper functions for softmax decomposition.
2878 
2879 /// Given an N-dimensional tensor x, this method converts
2880 /// softmax(x) to the following sequence of operations:
2881 ///
2882 /// 1. Compute the max of x along dimension d. This results
2883 /// in an N-1 dimensional tensor m.
2884 /// m = max(x, dim = d)
2885 ///
2886 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2887 /// an N dimensional tensor z.
2888 /// z = exp(x - m)
2889 ///
2890 /// 3. Compute the sum of z along dimension d. This results in
2891 /// an N-1 dimensional tensor l.
2892 /// l = sum(z, dim = d)
2893 ///
2894 /// 4. Divide z and l. This gives the N-dimensional softmax.
2895 /// softmax = z / l
2896 ///
2897 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2898  OpBuilder::InsertionGuard guard(b);
2899  b.setInsertionPoint(*this);
2900  Location loc = getLoc();
2901  Value input = getInput();
2902  ShapedType inputType = getInputOperandType();
2903  Type elementType = inputType.getElementType();
2904  int64_t reductionDim = getDimension();
2905  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2906  Value output = getOutput();
2907  dims.erase(dims.begin() + reductionDim);
2908  // Step 1: Compute max along dim.
2909  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2910  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2911  elementType, b, loc,
2912  /*useOnlyFiniteValue=*/true);
2913  Value neutralForMaxFInit =
2914  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2915  .result();
2916  Value max =
2917  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2918 
2919  // Step 2: Subtract max from input and exponentiate.
2920  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2921 
2922  // Step 3: Compute sum along dim.
2923  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2924  b, loc, /*useOnlyFiniteValue=*/true);
2925  Value zeroInit =
2926  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2927  Value denominator =
2928  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2929 
2930  // Step 4: Compute softmax.
2931  Value result =
2932  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2933  return SmallVector<Value>{result};
2934 }
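// Illustrative numeric check of the steps above for x = [1, 2], d = 0:
// m = 2, z = exp(x - m) = [0.368, 1.0], l = sum(z) = 1.368, and
// softmax = z / l = [0.269, 0.731].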
2935 
2936 //===----------------------------------------------------------------------===//
2937 // WinogradFilterTransformOp
2938 //===----------------------------------------------------------------------===//
2939 
2940 LogicalResult WinogradFilterTransformOp::verify() {
2941  auto filterType = cast<ShapedType>(getFilter().getType());
2942  ArrayRef<int64_t> filterShape = filterType.getShape();
2943  int64_t filterH = filterShape[getFilterHDim()];
2944  int64_t filterW = filterShape[getFilterWDim()];
2945  int64_t r = getR();
2946  int64_t m = getM();
2947 
2948  if (filterH != r && filterH != 1)
2949  return emitOpError("expected filter height to be either r or 1");
2950  if (filterW != r && filterW != 1)
2951  return emitOpError("expected filter width to be either r or 1");
2952  if (filterH == 1 && filterW == 1)
2953  return emitOpError("expected either filter height or width to be r");
2954 
2955  SmallVector<int64_t> expectedOutputShape;
2956  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2957  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2958  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2959  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2960 
2961  auto outputType = cast<ShapedType>(getOutput().getType());
2962  ArrayRef<int64_t> outputShape = outputType.getShape();
2963  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2964  return emitOpError("unexpected output shape");
2965  }
2966  return success();
2967 }
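// For example, with m = 4 and r = 3 (the F(4, 3) Winograd algorithm),
// alpha = m + r - 1 = 6, so a filter of type tensor<2x3x3x5xf32>
// (F = 2, KH = KW = 3, C = 5) must transform into a tensor<6x6x5x2xf32>
// (a sketch; the shapes and SSA names are illustrative):
//
//   %1 = linalg.winograd_filter_transform m(4) r(3)
//          ins(%filter : tensor<2x3x3x5xf32>)
//          outs(%init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>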
2968 
2969 SmallVector<Range>
2970 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2971  Location loc = getLoc();
2972  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2973  IntegerAttr oneAttr = builder.getIndexAttr(1);
2974  Value filter = getFilter();
2975  int64_t filterRank = getFilterOperandRank();
2976  SmallVector<Range> loopBounds(filterRank);
2977  for (unsigned dim = 0; dim < filterRank; ++dim) {
2978  loopBounds[dim].offset = zeroAttr;
2979  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2980  loopBounds[dim].stride = oneAttr;
2981  }
2982  return loopBounds;
2983 }
2984 
2985 SmallVector<utils::IteratorType>
2986 WinogradFilterTransformOp::getLoopIteratorTypes() {
2987  int64_t filterRank = getFilterOperandRank();
2988  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2989  utils::IteratorType::parallel);
2990  return iteratorTypes;
2991 }
2992 
2993 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2994  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2995  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2996  SmallVector<OpFoldResult> &resultSizes) {
2997  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2998  ShapedType filterType = getFilterOperandType();
2999  ArrayRef<int64_t> filterShape = filterType.getShape();
3000  int64_t filterH = filterShape[getFilterHDim()];
3001  int64_t filterW = filterShape[getFilterWDim()];
3002  int64_t m = getM();
3003  int64_t r = getR();
3004  int64_t alpha = m + r - 1;
3005  int64_t alphaH = filterH != 1 ? alpha : 1;
3006  int64_t alphaW = filterW != 1 ? alpha : 1;
3007  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3008  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3009 
3010  resultOffsets.append(
3011  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3012  resultSizes.append(
3013  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3014 
3015  return success();
3016 }
3017 
3018 /// Implement tiling for winograd_filter_transform.
3019 /// The input of winograd_filter_transform is (F, KH, KW, C).
3020 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3021 /// Users can specify the tile sizes of F and C.
3022 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3023 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3024 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3025  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3026  ArrayRef<OpFoldResult> sizes) {
3027  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3028  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3029  ShapedType filterType = getFilterOperandType();
3030  ArrayRef<int64_t> filterShape = filterType.getShape();
3031  int64_t filterH = filterShape[getFilterHDim()];
3032  int64_t filterW = filterShape[getFilterWDim()];
3033  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3034  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3035  SmallVector<Value> tiledOperands;
3036  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3037 
3038  sliceOffsets.append(
3039  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3040  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3041  sizes[getFilterCDim()]});
3042  int64_t filterRank = getFilterOperandRank();
3043  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3044  Location loc = getLoc();
3045  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3046  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3047  tiledOperands.emplace_back(filterSlice);
3048 
3049  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3050  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3051  resultSizes)))
3052  return failure();
3053 
3054  int64_t outputRank = getOutputOperandRank();
3055  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3056  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3057  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3058  tiledOperands.emplace_back(outputSlice);
3059 
3060  SmallVector<Type> resultTypes;
3061  resultTypes.push_back(tiledOperands[1].getType());
3062  Operation *tiledOp =
3063  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3064 
3065  return TilingResult{
3066  {tiledOp},
3067  SmallVector<Value>(tiledOp->getResults()),
3068  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3069 }
3070 
3071 //===----------------------------------------------------------------------===//
3072 // WinogradInputTransformOp
3073 //===----------------------------------------------------------------------===//
3074 
3075 LogicalResult WinogradInputTransformOp::verify() {
3076  auto inputType = cast<ShapedType>(getInput().getType());
3077  ArrayRef<int64_t> inputShape = inputType.getShape();
3078  int64_t inputH = inputShape[getInputHDim()];
3079  int64_t inputW = inputShape[getInputWDim()];
3080  int m = getM();
3081  int r = getR();
3082  int64_t tileSize = m + r - 1;
3083 
3084  auto outputType = cast<ShapedType>(getOutput().getType());
3085  ArrayRef<int64_t> outputShape = outputType.getShape();
3086  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3087  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3088 
3089  SmallVector<int64_t> expectedOutputShape(6, inputH);
3090  if (ShapedType::isDynamic(inputH)) {
3091  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3092  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3093  } else {
3094  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3095  expectedOutputShape[getOutputTileHDim()] =
3096  leftTransform ? (inputH - (r - 1)) / m : inputH;
3097  }
3098  if (ShapedType::isDynamic(inputW)) {
3099  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3100  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3101  } else {
3102  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3103  expectedOutputShape[getOutputTileWDim()] =
3104  rightTransform ? (inputW - (r - 1)) / m : inputW;
3105  }
3106  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3107  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3108 
3109  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3110  return emitOpError("the output shape does not match the expected shape");
3111  }
3112  return success();
3113 }
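// For example, with m = 4 and r = 3, an input of type tensor<2x10x10x5xf32>
// (N = 2, H = W = 10, C = 5) gives tileH = tileW = (10 - (3 - 1)) / 4 = 2 and
// alphaH = alphaW = 6, so the expected output type is tensor<6x6x2x2x2x5xf32>
// (a sketch; the shapes and SSA names are illustrative):
//
//   %1 = linalg.winograd_input_transform m(4) r(3)
//          ins(%input : tensor<2x10x10x5xf32>)
//          outs(%init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>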
3114 
3115 SmallVector<Range>
3116 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3117  Location loc = getLoc();
3118  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3119  IntegerAttr oneAttr = builder.getIndexAttr(1);
3120  Value output = getOutput();
3121  int64_t outputRank = getOutputOperandRank();
3122  SmallVector<Range> loopBounds(outputRank);
3123  for (unsigned dim = 0; dim < outputRank; ++dim) {
3124  loopBounds[dim].offset = zeroAttr;
3125  // alphaH, alphaW, tileH, tileW, N, C
3126  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3127  loopBounds[dim].stride = oneAttr;
3128  }
3129  return loopBounds;
3130 }
3131 
3132 SmallVector<utils::IteratorType>
3133 WinogradInputTransformOp::getLoopIteratorTypes() {
3134  int64_t outputRank = getOutputOperandRank();
3135  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3136  utils::IteratorType::parallel);
3137  return iteratorTypes;
3138 }
3139 
3140 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3141  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3142  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3143  SmallVector<OpFoldResult> &resultSizes) {
3144  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3145  ShapedType outputType = getOutputOperandType();
3146  ArrayRef<int64_t> outputShape = outputType.getShape();
3147  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3148  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3149 
3150  int64_t m = getM();
3151  int64_t r = getR();
3152  int64_t alpha = m + r - 1;
3153  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3154  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3155 
3156  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3157  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3158 
3159  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3160  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3161  offsets[getOutputCDim()]});
3162  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3163  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3164  sizes[getOutputCDim()]});
3165 
3166  return success();
3167 }
3168 
3169 /// Implement tiling for winograd_input_transform.
3170 /// The input of winograd_input_transform is (N, H, W, C).
3171 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3172 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3173 /// `offsets` are the values for the offsets of tileH, tileW, N, C for one tile.
3174 /// `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3175 FailureOr<TilingResult>
3176 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3177  ArrayRef<OpFoldResult> offsets,
3178  ArrayRef<OpFoldResult> sizes) {
3179  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3180  int64_t m = getM();
3181  int64_t r = getR();
3182 
3183  ShapedType outputType = getOutputOperandType();
3184  ArrayRef<int64_t> outputShape = outputType.getShape();
3185  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3186  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3187 
3188  Location loc = getLoc();
3189  MLIRContext *context = builder.getContext();
3190  auto identityAffineMap =
3191  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3192  auto offsetAffineMap =
3193  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3194  Value mappedOffsetH = affine::makeComposedAffineApply(
3195  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3196  offsets[getOutputTileHDim()]);
3197  Value mappedOffsetW = affine::makeComposedAffineApply(
3198  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3199  offsets[getOutputTileWDim()]);
3200  auto sizeAffineMap = AffineMap::get(
3201  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3202  Value mappedSizeH = affine::makeComposedAffineApply(
3203  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3204  Value mappedSizeW = affine::makeComposedAffineApply(
3205  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3206 
3207  SmallVector<Value> tiledOperands;
3208  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3209 
3210  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3211  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3212  sliceOffsets.append(
3213  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3214  OpFoldResult sizeH =
3215  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3216  OpFoldResult sizeW =
3217  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3218  sliceSizes.append(
3219  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3220  int64_t inputRank = getInputOperandRank();
3221  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3222  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3223  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3224  tiledOperands.emplace_back(inputSlice);
3225 
3226  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3227  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3228  resultSizes)))
3229  return failure();
3230 
3231  int64_t outputRank = getOutputOperandRank();
3232  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3233  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3234  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3235  tiledOperands.emplace_back(outputSlice);
3236 
3237  SmallVector<Type> resultTypes;
3238  resultTypes.push_back(tiledOperands[1].getType());
3239  Operation *tiledOp =
3240  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3241 
3242  return TilingResult{
3243  {tiledOp},
3244  SmallVector<Value>(tiledOp->getResults()),
3245  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3246 }
3247 
3248 //===----------------------------------------------------------------------===//
3249 // WinogradOutputTransformOp
3250 //===----------------------------------------------------------------------===//
3251 
3252 LogicalResult WinogradOutputTransformOp::verify() {
3253  auto valueType = cast<ShapedType>(getValue().getType());
3254  ArrayRef<int64_t> valueShape = valueType.getShape();
3255  int64_t valueH = valueShape[getValueAlphaHDim()];
3256  int64_t valueW = valueShape[getValueAlphaWDim()];
3257  int64_t valueTileH = valueShape[getValueTileHDim()];
3258  int64_t valueTileW = valueShape[getValueTileWDim()];
3259  int m = getM();
3260  int r = getR();
3261  bool leftTransform = valueH != 1;
3262  bool rightTransform = valueW != 1;
3263 
3264  int64_t outputRank = getOutputOperandRank();
3265  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3266  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3267  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3268  } else {
3269  if (valueH != (leftTransform ? m + r - 1 : 1))
3270  return emitOpError("expected input height to equal the input tile size");
3271  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3272  }
3273  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3274  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3275  } else {
3276  if (valueW != (rightTransform ? m + r - 1 : 1))
3277  return emitOpError("expected input width to equal the input tile size");
3278  expectedOutputShape[getOutputWDim()] =
3279  (rightTransform ? m : 1) * valueTileW;
3280  }
3281  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3282  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3283 
3284  auto outputType = cast<ShapedType>(getOutput().getType());
3285  ArrayRef<int64_t> outputShape = outputType.getShape();
3286  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3287  return emitOpError("the output shape does not match the expected shape");
3288  }
3289  return success();
3290 }
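// For example, with m = 4 and r = 3, a value of type tensor<6x6x2x2x2x2xf32>
// (alphaH = alphaW = 6, tileH = tileW = 2, N = 2, F = 2) reconstructs an
// output with H = W = m * tileH = 8 (a sketch; the shapes and SSA names are
// illustrative):
//
//   %1 = linalg.winograd_output_transform m(4) r(3)
//          ins(%value : tensor<6x6x2x2x2x2xf32>)
//          outs(%init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>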
3291 
3292 SmallVector<Range>
3293 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3294  Location loc = getLoc();
3295  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3296  IntegerAttr oneAttr = builder.getIndexAttr(1);
3297  Value value = getValue();
3298  int64_t valueRank = getValueOperandRank();
3299  SmallVector<Range> loopBounds(valueRank);
3300  for (unsigned dim = 0; dim < valueRank; ++dim) {
3301  loopBounds[dim].offset = zeroAttr;
3302  // alphaH, alphaW, tileH, tileW, N, F
3303  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3304  loopBounds[dim].stride = oneAttr;
3305  }
3306  return loopBounds;
3307 }
3308 
3309 SmallVector<utils::IteratorType>
3310 WinogradOutputTransformOp::getLoopIteratorTypes() {
3311  int64_t valueRank = getValueOperandRank();
3312  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3313  utils::IteratorType::parallel);
3314  return iteratorTypes;
3315 }
3316 
3317 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3318  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3319  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3320  SmallVector<OpFoldResult> &resultSizes) {
3321  int64_t m = getM();
3322 
3323  Location loc = getLoc();
3324  MLIRContext *context = builder.getContext();
3325  auto identityAffineMap =
3326  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3327  auto affineMap =
3328  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3329 
3330  ShapedType valueType = getValueOperandType();
3331  ArrayRef<int64_t> valueShape = valueType.getShape();
3332  int64_t valueH = valueShape[0];
3333  int64_t valueW = valueShape[1];
3334  Value mappedOffsetH = affine::makeComposedAffineApply(
3335  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3336  offsets[getValueTileHDim()]);
3337  Value mappedOffsetW = affine::makeComposedAffineApply(
3338  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3339  offsets[getValueTileWDim()]);
3340  Value mappedSizeH = affine::makeComposedAffineApply(
3341  builder, loc, affineMap, sizes[getValueTileHDim()]);
3342  Value mappedSizeW = affine::makeComposedAffineApply(
3343  builder, loc, affineMap, sizes[getValueTileWDim()]);
3344 
3345  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3346  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3347  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3348  OpFoldResult sizeH =
3349  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3350  OpFoldResult sizeW =
3351  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3352 
3353  resultOffsets.append(
3354  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3355  resultSizes.append(
3356  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3357  return success();
3358 }
3359 
3360 /// Implement tiling for winograd_output_transform.
3361 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3362 /// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3363 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3364 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3365 /// for the sizes of tileH, tileW, N, F for one tile.
3366 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3367  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3368  ArrayRef<OpFoldResult> sizes) {
3369  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3370  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3371  Location loc = getLoc();
3372  SmallVector<Value> tiledOperands;
3373  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3374 
3375  ShapedType valueType = getValueOperandType();
3376  ArrayRef<int64_t> valueShape = valueType.getShape();
3377  int64_t alphaH = valueShape[getValueAlphaHDim()];
3378  int64_t alphaW = valueShape[getValueAlphaWDim()];
3379  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3380  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3381 
3382  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3383  offsets[getValueTileWDim()], offsets[getValueNDim()],
3384  offsets[getValueFDim()]});
3385  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3386  sizes[getValueTileWDim()], sizes[getValueNDim()],
3387  sizes[getValueFDim()]});
3388  int64_t valueRank = getValueOperandRank();
3389  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3390  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3391  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3392  tiledOperands.emplace_back(valueSlice);
3393 
3394  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3395  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3396  resultSizes)))
3397  return failure();
3398 
3399  int64_t outputRank = getOutputOperandRank();
3400  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3401  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3402  loc, getOutput(), resultOffsets, resultSizes, strides);
3403  tiledOperands.emplace_back(outputSlice);
3404 
3405  SmallVector<Type> resultTypes;
3406  resultTypes.push_back(tiledOperands[1].getType());
3407  Operation *tiledOp =
3408  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3409 
3410  return TilingResult{
3411  {tiledOp},
3412  SmallVector<Value>(tiledOp->getResults()),
3413  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3414 }
3415 
3416 //===----------------------------------------------------------------------===//
3417 // LinalgDialect
3418 // TODO: Merge with the LinalgDialect block at the bottom
3419 //===----------------------------------------------------------------------===//
3420 
3421 // Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
3422 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3423  auto explicitRange = subMap.getResults();
3424  auto defaultRange = fullMap.getResults();
3425  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3426  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3427  llvm::set_union(explicitSet, defaultSet);
3428  return explicitSet == defaultSet;
3429 }
3430 
3431 /// Check if the user-defined map is a valid broadcast map. Broadcast indexing
3432 /// maps are defined in the context of the corresponding default indexing maps
3433 /// for the given op. This makes the check simple: just compare the number of
3434 /// result dims.
3435 /// Returns true if `explictMap` is broadcast with respect to the
3436 /// defaultMap.
3437 static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3438  return explictMap.getNumResults() < defaultMap.getNumResults();
3439 }
3440 
3441 /// Verifies the broadcast and transpose semantics specified by the explicit
3442 /// indexing map for the MatmulOp \p op for each operand specified by \p
3443 /// opIndex.
3444 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3445  unsigned opIndex) {
3446  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3447  SmallVector<AffineMap, 3> defaultIndexingMaps =
3448  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3449 
3450  auto opIndexingMap = opIndexingMaps[opIndex];
3451  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3452  // Check general validity of indexing map results.
3453  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3454  return matmulOp->emitOpError()
3455  << "Unexpected dim expression in map result.";
3456 
3457  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3458  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3459  return matmulOp->emitOpError()
3460  << "Invalid broadcast requested, should be (d2).";
3461  }
3462  return success();
3463  }
3464  return success();
3465 }
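// For example, an LHS broadcast passes this verification: its map (d2) is a
// subset of the default LHS map (d0, d2) and is a function of the contraction
// dim, so isValidLhsRhsBroadcastMap accepts it (a sketch; the shapes are
// illustrative):
//
//   linalg.matmul indexing_maps = [
//       affine_map<(d0, d1, d2) -> (d2)>,       // broadcast LHS
//       affine_map<(d0, d1, d2) -> (d2, d1)>,
//       affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%a, %b : tensor<5xf32>, tensor<5x7xf32>)
//     outs(%c : tensor<3x7xf32>) -> tensor<3x7xf32>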
3466 
3467 // Check general validity of input indexing map.
3468 static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
3469  AffineMap opIndexingMap,
3470  AffineMap defaultIndexingMap, bool isLHS) {
3471  // Check the result dims are valid.
3472  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3473  return batchMatmulOp->emitOpError()
3474  << "Unexpected result dim expression (outside the set of default "
3475  "result dims).";
3476 
3477  // Check for valid number of result dims of input maps.
3478  if (opIndexingMap.getNumResults() > 3)
3479  return batchMatmulOp->emitOpError()
3480  << "no. of result dim expressions exceeds 3.";
3481 
3482  auto hasValidBatchDim = [](AffineMap map) {
3483  AffineExpr batchDim = map.getResult(0);
3484  return batchDim.isFunctionOfDim(0);
3485  };
3486 
3487  // Check if the requested broadcast is valid.
3488  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3489  if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3490  return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3491  } else if (!hasValidBatchDim(opIndexingMap)) {
3492  return batchMatmulOp->emitOpError()
3493  << "Invalid batch dimension expression.";
3494  }
3495  return success();
3496 }
3497 
3498 /// This function checks if the given AffineMap for the output of a
3499 /// BatchMatmulOp has exactly 3 result dimensions and if the output map result
3500 /// dimensions are valid.
3501 static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3502  AffineMap opIndexingMap) {
3503  if (opIndexingMap.getNumResults() != 3)
3504  return batchMatmulOp->emitOpError()
3505  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3506  << ").";
3507 
3508  auto areValidOutputResultDim = [](AffineMap outputMap) {
3509  return outputMap.getResult(0).isFunctionOfDim(0) &&
3510  outputMap.getResult(1).isFunctionOfDim(1) &&
3511  outputMap.getResult(2).isFunctionOfDim(2);
3512  };
3513 
3514  if (!areValidOutputResultDim(opIndexingMap))
3515  return batchMatmulOp->emitOpError()
3516  << "Invalid output map result dimension.";
3517 
3518  return success();
3519 }
3520 
3521 /// Verifies the broadcast and transpose semantics specified by the explicit
3522 /// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
3523 static LogicalResult
3524 verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
3525  unsigned opIndex) {
3526  SmallVector<AffineMap, 3> opIndexingMaps =
3527  batchMatmulOp.getIndexingMapsArray();
3528  SmallVector<AffineMap, 3> defaultIndexingMaps =
3529  batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3530 
3531  if (opIndexingMaps.size() != 3)
3532  return batchMatmulOp->emitOpError()
3533  << "Indexing_map attribute must have 3 affine maps.";
3534 
3535  auto opIndexingMap = opIndexingMaps[opIndex];
3536  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3537 
3538  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
3539  return failure();
3540 
3541  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3542  opIndex == 0)))
3543  return failure();
3544 
3545  return success();
3546 }
3547 
3548 namespace mlir {
3549 namespace linalg {
3550 
3551 //===----------------------------------------------------------------------===//
3552 // MatMulOp
3553 //===----------------------------------------------------------------------===//
3554 
3555 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3556 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3557  AffineExpr d0, d1, d2;
3558  SmallVector<AffineMap> indexingMaps;
3559  bindDims(context, d0, d1, d2);
3560  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3561  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3562  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3563  return indexingMaps;
3564 }
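// In other words, for C(d0, d1) += A(d0, d2) * B(d2, d1) the default maps are:
//
//   affine_map<(d0, d1, d2) -> (d0, d2)>   // A
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // B
//   affine_map<(d0, d1, d2) -> (d0, d1)>   // C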
3565 
3566 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3567  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3568  utils::IteratorType::parallel,
3569  utils::IteratorType::reduction};
3570 }
3571 
3572 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3573 
3574 std::string MatmulOp::getLibraryCallName() {
3575  return generateLibraryCallName(getOperation());
3576 }
3577 
3578 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3579 
3580 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3581 /// the user-defined indexing maps are not equal to the default maps.
3582 bool MatmulOp::hasUserDefinedMaps() {
3583  SmallVector<AffineMap, 3> defaultMaps =
3584  getDefaultIndexingMaps(this->getContext());
3585  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3586  return defaultMaps != explicitMaps;
3587 }
3588 
3589 /// Implements the block region builder for the MatmulOp. This is called by
3590 /// 'fillStructuredOpRegion'.
3591 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3592  ArrayRef<NamedAttribute> attrs) {
3593  assert(block.getNumArguments() == 3 &&
3594  "MatmulOp regionBuilder expects 3 args");
3595  RegionBuilderHelper helper(b, block);
3596  SmallVector<Value> yields;
3597 
3598  TypeFn castVal = TypeFn::cast_signed;
3599  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3600  return attr.getName() == "cast";
3601  });
3602  if (castIter != attrs.end()) {
3603  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3604  castVal = attr.getValue();
3605  }
3606 
3607  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3608  block.getArgument(0));
3609  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3610  block.getArgument(1));
3611  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3612  Value value4 =
3613  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3614  yields.push_back(value4);
3615  helper.yieldOutputs(yields);
3616 }
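// When cast_signed is a no-op (the operands already have the accumulator
// type), the built region reduces to a multiply-accumulate; a sketch of the
// emitted block for f32 operands:
//
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32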
3617 
3618 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3619 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3620  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3621  AffineExpr exp = bcastMap.getResult(0);
3622  // The map is invalid if the common (reduction) dimension of the matmul is not present.
3623  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3624 }
3625 
3626 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3627  if (parser.parseOptionalKeyword("indexing_maps"))
3628  return ArrayAttr{
3629  nullptr}; // Success in case indexing_maps was not provided.
3630 
3631  ArrayAttr arrayAttr;
3632  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3633  return failure();
3634 
3635  if (llvm::any_of(arrayAttr,
3636  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3637  return parser.emitError(parser.getCurrentLocation())
3638  << "element of indexing_maps array is not an affine_map";
3639 
3640  return arrayAttr;
3641 }
3642 
3643 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3644  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3645  if (failed(indexingMapsAttr))
3646  return failure();
3647 
3648  if (*indexingMapsAttr == nullptr) {
3649  auto indexingMapAttrs = llvm::map_to_vector(
3650  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3651  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3652  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3653  }
3654 
3655  result.addAttribute("indexing_maps", *indexingMapsAttr);
3656  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3657  MatmulOp::getRegionBuilder());
3658 }
3659 
3660 void MatmulOp::print(OpAsmPrinter &p) {
3661  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3662  MatmulOp::getDefaultIndexingMaps(getContext()),
3663  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3664  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3665  p << " indexing_maps = [";
3666  llvm::interleaveComma(getIndexingMaps(), p,
3667  [&](Attribute attr) { p.printAttribute(attr); });
3668  p << "]";
3669  }
3670 
3671  SmallVector<StringRef, 3> elidedAttrs = {
3672  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3673  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3674  elidedAttrs);
3675 }
3676 
3677 /// Verify the user defined indexing maps.
3678 LogicalResult MatmulOp::verify() {
3679  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3680  if (!hasUserDefinedMaps())
3681  return success();
3682 
3683  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3684  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3685  return failure();
3686  }
3687  return success();
3688 }
3689 
3690 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3691  return memref::foldMemRefCast(*this);
3692 }
3693 
3694 void MatmulOp::getEffects(
3695  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3696  &effects) {
3697  if (hasPureTensorSemantics())
3698  return;
3699  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3700 }
3701 
3702 Speculation::Speculatability MatmulOp::getSpeculatability() {
3703  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3704 }
3705 
3706 //===----------------------------------------------------------------------===//
3707 // ContractOp
3708 //===----------------------------------------------------------------------===//
3709 
3710 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3711  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3712  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3713  // domains are all the same, and each implements a projected permutation.
3714  // Each iteration space dim must occur for at least one operand and either
3715  // takes part in a contraction/reduction or else has parallel iteration type.
3716  // We have that a dim is a contraction/reduction dim if and only if the dim
3717  // does not occur for the output operand. We use this fact for fast inference:
3718  // NB: In case we allow dims to occur solely for one input, the above still
3719  // holds: per the einsum semantics, these are reduction dims as well.
3720  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3721  for (auto result : outAffineMap.getResults()) {
3722  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3723  assert(dimExpr && "affine_map is a projected permutation");
3724  dimsInOutput[dimExpr.getPosition()] = true;
3725  }
3726 
3727  SmallVector<utils::IteratorType> iteratorTypes;
3728  for (auto dimOccursInOutput : dimsInOutput)
3729  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3730  : utils::IteratorType::reduction);
3731 
3732  return iteratorTypes;
3733 }
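// For example, with matmul-like maps the dims m and n occur in the output and
// are inferred as parallel, while k occurs only in the inputs and is inferred
// as reduction (a sketch; the shapes are illustrative):
//
//   linalg.contract indexing_maps = [
//       affine_map<(m, n, k) -> (m, k)>,
//       affine_map<(m, n, k) -> (k, n)>,
//       affine_map<(m, n, k) -> (m, n)>]
//     ins(%A, %B : tensor<4x3xf32>, tensor<3x5xf32>)
//     outs(%C : tensor<4x5xf32>) -> tensor<4x5xf32>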
3734 
3735 unsigned ContractOp::getNumRegionArgs() { return 3; }
3736 
3737 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
3738 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3739  ArrayRef<NamedAttribute> attrs) {
3740  assert(block.getNumArguments() == 3 &&
3741  "ContractOp regionBuilder expects 3 args");
3742  RegionBuilderHelper helper(b, block);
3743 
3744  TypeFn castSignedness = TypeFn::cast_signed;
3745  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3746  return attr.getName() == "cast";
3747  });
3748  if (castIter != attrs.end()) {
3749  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3750  castSignedness = attr.getValue();
3751  }
3752 
3753  // TODO: Support fields with operators besides mult & add.
3754  Type outType = block.getArgument(2).getType();
3755  Value lhsAtOutType =
3756  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3757  Value rhsAtOutType =
3758  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3759  Value productAtOutType =
3760  helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3761  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3762  productAtOutType);
3763  helper.yieldOutputs({result});
3764 }
3765 
3766 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3767  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3768  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3769  return parser.emitError(parser.getCurrentLocation(),
3770  "expected 'indexing_maps' attribute");
3771  result.addAttribute("indexing_maps", *indexingMapsAttr);
3772 
3773  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3774  regionBuilder);
3775 }
3776 
3777 void ContractOp::print(OpAsmPrinter &p) {
3778  p << " indexing_maps = [";
3779  llvm::interleaveComma(getIndexingMaps(), p,
3780  [&](Attribute attr) { p.printAttribute(attr); });
3781  p << "]";
3782  printNamedStructuredOp(
3783  p, getOperation(), getInputs(), getOutputs(),
3784  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3785 }
3786 
3787 LogicalResult ContractOp::verify() {
3788  int iterationSpaceDims = -1;
3789  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3790  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3791  // access an input operand (so occurrence count can be at most 2) and
3792  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3793  SmallVector<size_t> inOccurrences;
3794  SmallVector<size_t> outOccurrences;
3795 
3796  // A helper so that for each operand's affine_map and type we check that ...
3797  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3798  bool isInput) -> LogicalResult {
3799  // ... the affine_map is a projected permutation;
3800  if (!affineMap.isProjectedPermutation())
3801  return emitError("provided affine_map is not a projected permutation");
3802 
3803  // ... the rank of the affine_map's results and corresponding type match;
3804  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3805  if (affineMap.getNumResults() != shapedType.getRank())
3806  return emitError("ranks of shaped operand and results of corresponding "
3807  "affine_map differ");
3808  } else if (affineMap.getNumResults() != 0) {
3809  return emitError("affine_map specifies shaped access while operand has "
3810  "non-shaped type");
3811  }
3812 
3813  // ... the rank of the affine_map's domain is the same as those seen prior;
3814  if (iterationSpaceDims == -1) {
3815  iterationSpaceDims = affineMap.getNumDims();
3816  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3817  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3818  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3819  return emitError("iteration spaces of provided affine_maps differ");
3820  }
3821 
3822  // ... update counts of dims used to access either an input or the output.
3823  for (AffineExpr affineExpr : affineMap.getResults()) {
3824  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3825  if (!affineDimExpr)
3826  llvm_unreachable("affine_map is a projected permutation");
3827 
3828  if (isInput)
3829  inOccurrences[affineDimExpr.getPosition()] += 1;
3830  else
3831  outOccurrences[affineDimExpr.getPosition()] += 1;
3832  }
3833 
3834  return success();
3835  };
3836 
3837  for (auto &&[affineMap, operandType, isInput] :
3838  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3839  SmallVector<bool>{true, true, false})) {
3840  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3841  return failure(); // NB: checkAffineMapAndType will emit relevant error.
3842  }
3843 
3844  bool hasContractingDim = false;
3845  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3846  size_t inOccCount = inOccurrences[dimIndex];
3847  size_t outOccCount = outOccurrences[dimIndex];
3848 
3849  // We have a contracting dim if and only if ...
3850  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3851 
3852  if (inOccCount == 0 && outOccCount == 0)
3853  return emitError() << "iteration space dim at index " << dimIndex
3854  << " not used to access any operand";
3855 
3856  // NB: We disallow a dim which occurs for only one input operand and not
3857  // for the output. In terms of einsum semantics such dims have a
3858  // sensible meaning - namely an additional reduction per each such dim.
3859  // By contrast, the ContractionOpInterface does not know about this
3860  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3861  // while vector.contract's verifier accepts dims of this kind many of
3862  // its lowerings give up on encountering these dims.
3863  // TODO: Remove following once we have comprehensive support for input-only
3864  // reduction dims, at both the linalg- and vector-dialect levels.
3865  if (inOccCount == 1 && outOccCount != 1)
3866  return emitError()
3867  << "iteration space dim at index " << dimIndex
3868  << " is neither a contracting dim nor of parallel iteration type";
3869  }
3870 
3871  if (!hasContractingDim)
3872  return emitError("'indexing_maps' do not specify a contracting dimension");
3873 
3874  return success();
3875 }
3876 
3877 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3878  return memref::foldMemRefCast(*this);
3879 }
3880 
3881 void ContractOp::getEffects(
3882  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3883  &effects) {
3884  if (hasPureTensorSemantics())
3885  return;
3886  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3887 }
3888 
3889 Speculation::Speculatability ContractOp::getSpeculatability() {
3890  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3891 }
3892 
3893 //===----------------------------------------------------------------------===//
3894 // Implementation of BatchMatmulOp
3895 //===----------------------------------------------------------------------===//
3896 SmallVector<AffineMap>
3897 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3898  AffineExpr d0, d1, d2, d3;
3899  SmallVector<AffineMap> indexingMaps;
3900  bindDims(context, d0, d1, d2, d3);
3901  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3902  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3903  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3904  return indexingMaps;
3905 }
3906 
3907 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3908  return SmallVector<utils::IteratorType>{
3909  utils::IteratorType::parallel, utils::IteratorType::parallel,
3910  utils::IteratorType::parallel, utils::IteratorType::reduction};
3911 }
3912 
3913 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3914 
3915 std::string BatchMatmulOp::getLibraryCallName() {
3916  return generateLibraryCallName(getOperation());
3917 }
3918 
3919 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3920 /// the user-defined indexing maps are not equal to the default maps.
3921 bool BatchMatmulOp::hasUserDefinedMaps() {
3922  SmallVector<AffineMap, 3> defaultMaps =
3923  getDefaultIndexingMaps(this->getContext());
3924  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3925  return defaultMaps != explicitMaps;
3926 }
3927 
3928 /// Returns true if the given broadcast map bcastMap is valid for this op.
3929 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
3930  assert(bcastMap.getNumResults() < 3 &&
3931  "Expected less than 3 result dim expr.");
3932  bool isValid = false;
3933  enum Indices { batchPos, mPos, nPos, kPos };
3934  if (bcastMap.getNumResults() == 1) {
3935  AffineExpr exp = bcastMap.getResult(0);
3936  isValid = exp.isFunctionOfDim(kPos);
3937  } else if (bcastMap.getNumResults() == 2) {
3938  AffineExpr exp0 = bcastMap.getResult(0);
3939  AffineExpr exp1 = bcastMap.getResult(1);
3940  isValid = isLHS
3941  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
3942  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
3943  }
3944  return isValid;
3945 }
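// For example, a rank-1 LHS broadcast over the batch and M dims uses the map
// (d0, d1, d2, d3) -> (d3), which this predicate accepts for isLHS = true
// (a sketch; the shapes are illustrative):
//
//   linalg.batch_matmul indexing_maps = [
//       affine_map<(d0, d1, d2, d3) -> (d3)>,           // broadcast LHS
//       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
//       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
//     ins(%a, %b : tensor<5xf32>, tensor<2x5x7xf32>)
//     outs(%c : tensor<2x3x7xf32>) -> tensor<2x3x7xf32>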
3946 
3947 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3948  ArrayRef<NamedAttribute> attrs) {
3949  assert(block.getNumArguments() == 3 &&
3950  "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3951  RegionBuilderHelper helper(b, block);
3952  SmallVector<Value> yields;
3953 
3954  auto toType = block.getArgument(2).getType();
3955  Value castValA =
3956  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
3957  Value castValB =
3958  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
3959  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3960  Value addVal =
3961  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
3962  yields.push_back(addVal);
3963  helper.yieldOutputs(yields);
3964 }
3965 
3966 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3967  SmallVector<Attribute, 3> indexingMapsAttr;
3968  Attribute mapAttr;
3969  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
3970  if (parser.parseEqual())
3971  return failure();
3972 
3973  if (parser.parseLSquare())
3974  return failure();
3975 
3976  do {
3977  if (parser.parseAttribute(mapAttr))
3978  return failure();
3979  if (!isa<AffineMapAttr>(mapAttr)) {
3980  return parser.emitError(parser.getCurrentLocation(),
3981  "expected affine map attribute");
3982  }
3983  indexingMapsAttr.push_back(mapAttr);
3984 
3985  if (parser.parseOptionalComma())
3986  break;
3987  } while (true);
3988 
3989  if (parser.parseRSquare())
3990  return failure();
3991  }
3992  // Initialize indexingMaps, if not supplied explicitly.
3993  if (indexingMapsAttr.empty()) {
3994  indexingMapsAttr = llvm::map_to_vector(
3995  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
3996  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3997  }
3998  result.addAttribute("indexing_maps",
3999  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4000 
4001  return ::parseNamedStructuredOp(parser, result,
4002  BatchMatmulOp::getNumRegionArgs(),
4003  BatchMatmulOp::getRegionBuilder());
4004 }
4005 
4006 void BatchMatmulOp::print(OpAsmPrinter &p) {
4007  SmallVector<StringRef, 3> elidedAttrs = {
4008  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4009  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4010  elidedAttrs);
4011 
4012  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
4013  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4014  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4015  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4016  p << " indexing_maps = [";
4017  llvm::interleaveComma(getIndexingMaps(), p,
4018  [&](Attribute attr) { p.printAttribute(attr); });
4019  p << "]";
4020  }
4021 }
4022 
4023 /// Verify the user defined indexing maps.
4024 LogicalResult BatchMatmulOp::verify() {
4025  // Verification of pure batch_matmul is handled by
4026  // verifyStructuredOpInterface().
4027  if (!hasUserDefinedMaps())
4028  return success();
4029 
4030  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4031  if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
4032  return failure();
4033  }
4034  return success();
4035 }
4036 
4037 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4038  SmallVectorImpl<OpFoldResult> &) {
4039  return memref::foldMemRefCast(*this);
4040 }
4041 
4042 void BatchMatmulOp::getEffects(
4043  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4044  &effects) {
4045  if (hasPureTensorSemantics())
4046  return;
4047  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4048 }
4049 
4050 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4051  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4052 }
4053 
4054 //===----------------------------------------------------------------------===//
4055 // PackOp/UnPackOp Common
4056 //===----------------------------------------------------------------------===//
4057 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4058 // updated mixed-tile-sizes attribute. A tile size is updated only
4059 // when:
4060 // * a dim from newPackedTy is static, and
4061 // * the corresponding size from mixedTiles is still dynamic.
4062 // Otherwise, the original tile size is preserved.
4063 // Note - packed-type-dim and mixed-tile-size should always match!
4064 static SmallVector<OpFoldResult>
4065 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4066  SmallVector<OpFoldResult> mixedTiles) {
4067  SmallVector<OpFoldResult> newMixedTileSizes;
4068  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4069  .getShape()
4070  .take_back(mixedTiles.size()),
4071  mixedTiles)) {
4072  int64_t shape = std::get<0>(it);
4073  if (shape == ShapedType::kDynamic) {
4074  newMixedTileSizes.push_back(std::get<1>(it));
4075  continue;
4076  }
4077 
4078  // If the current result dim is static, update the dynamic mixed-size
4079  // (provided the original value is dynamic).
4080  OpFoldResult tile = std::get<1>(it);
4081  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4082  // Already a constant
4083  newMixedTileSizes.push_back(tile);
4084  } else {
4085  assert(getConstantIntValue(tile).value() == shape &&
4086  "tile size and dim size don't match!");
4087  newMixedTileSizes.push_back(
4088  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4089  }
4090  }
4091 
4092  return newMixedTileSizes;
4093 }
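// Worked example (illustrative values): with newPackedTy = tensor<?x?x8x?xf32>
// and mixedTiles = [%t0, %t1], the trailing packed dims are [8, ?]. The static
// 8 replaces the dynamic %t0 with the constant attribute 8, while the dynamic
// trailing dim keeps %t1. Result: newMixedTileSizes = [8, %t1].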
4094 
4095 template <typename OpTy>
4096 static LogicalResult
4097 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4098  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4099  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4100  "applies to only pack or unpack operations");
4101  int64_t destRank = op.getDestRank();
4102  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4103  reifiedReturnShapes[0] =
4104  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4105  return success();
4106 }
4107 
4108 template <typename OpTy>
4109 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4110  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4111  "applies to only pack or unpack operations");
4112  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4113  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4114  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4115  assert(tiles.size() == dimsToTile.size() &&
4116  "tiles must match indices of dimension to block");
4117 // Bind dimension `i` to its tile factor.
4118  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4119  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4120  return dimAndTileMapping;
4121 }
4122 
4123 template <typename OpTy>
4124 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4125  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4126  "applies to only pack or unpack operations");
4127  Builder builder(op);
4128  SmallVector<OpFoldResult> mixedInnerTiles;
4129  unsigned dynamicValIndex = 0;
4130  for (int64_t staticTile : op.getStaticInnerTiles()) {
4131  if (!ShapedType::isDynamic(staticTile))
4132  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4133  else
4134  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4135  }
4136  return mixedInnerTiles;
4137 }
4138 
4139 template <typename OpTy>
4140 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4141  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4142  "applies to only pack or unpack operations");
4143  SmallVector<Value> dynamicTiles;
4144  SmallVector<int64_t> staticTiles;
4145  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4146  return staticTiles;
4147 }
4148 
4149 /// Returns true if `dimsPos` is invalid. It is invalid when:
4150 /// a) it contains duplicates;
4151 /// b) at least one dimension is out of bounds (a valid `dimPos` is >= 0 and < rank); or
4152 /// c) the number of elements in `dimsPos` is greater than `rank`.
4153 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4154  size_t rank) {
4155  size_t dimsPosSize = dimsPos.size();
4156  if (dimsPosSize > rank)
4157  return true;
4158  DenseSet<int64_t> uniqued;
4159  for (int64_t dim : dimsPos)
4160  uniqued.insert(dim);
4161  if (dimsPosSize != uniqued.size())
4162  return true;
4163  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4164  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4165  });
4166 }
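// Worked examples (illustrative values), with rank = 3:
//
//   dimsPos = [0, 2]       -> valid
//   dimsPos = [1, 1]       -> invalid (duplicate entry)
//   dimsPos = [0, 3]       -> invalid (3 is out of bounds for rank 3)
//   dimsPos = [0, 1, 2, 0] -> invalid (more entries than rank)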
4167 
4168 /// Returns true if each dimension of `sourceShape` is no larger than the
4169 /// corresponding dimension of `limitShape`; dynamic dimensions always pass.
4170 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4171  ArrayRef<int64_t> limitShape) {
4172  assert(
4173  sourceShape.size() == limitShape.size() &&
4174  "expected source shape rank, and limit of the shape to have same rank");
4175  return llvm::all_of(
4176  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4177  int64_t sourceExtent = std::get<0>(it);
4178  int64_t limit = std::get<1>(it);
4179  return ShapedType::isDynamic(sourceExtent) ||
4180  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4181  });
4182 }
4183 
4184 template <typename OpTy>
4185 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4186  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4187  "applies to only pack or unpack operations");
4188  Operation *op = packOrUnPack.getOperation();
4189 
4190  // Return true if we have a zero-value tile.
4191  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4192  return llvm::any_of(
4193  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4194  };
4195 
4196  // Verify tiles. Do not allow zero tiles.
4197  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4198  if (hasZeros(mixedTiles))
4199  return op->emitError("invalid zero tile factor");
4200 
4201  // Verify inner_dims_pos and outer_dims_perm.
4202  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4203  ? packOrUnPack.getSourceType()
4204  : packOrUnPack.getDestType();
4205  size_t unpackedRank = unpackedType.getRank();
4206  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4207  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4208  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
4209  return op->emitError("invalid inner_dims_pos vector");
4210  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4211  return op->emitError("invalid outer_dims_perm vector");
4212  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4213  return op->emitError("outer_dims_perm must be a permutation or empty");
4214 
4215  // Tiling factors must be less than or equal to the input rank for pack (or
4216  // output rank for unpack), and must match the number of `inner_dims_pos`.
4217  if (mixedTiles.size() > unpackedRank) {
4218  return op->emitError("tiling factors must be less than or equal to the "
4219  "input rank for pack or output rank for unpack");
4220  }
4221  if (mixedTiles.size() != innerDimsPos.size()) {
4222  return op->emitError(
4223  "tiling factors must equal the number of dimensions to tile");
4224  }
4225 
4226  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4227  ? packOrUnPack.getDestType()
4228  : packOrUnPack.getSourceType();
4229  size_t packedRank = packedType.getRank();
4230  // Require output rank to match input rank + number of blocking factors.
4231  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4232  if (expectedPackedRank != packedRank) {
4233  return op->emitError(
4234  "packed rank != (unpacked rank + num tiling factors), got ")
4235  << packedRank << " != " << expectedPackedRank;
4236  }
4237 
4238  // Verify result shape is greater than the minimum expected
4239  // by the pack operation, and that the output shape
4240  // represents full tiles.
4241  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4242  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4243  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4244  return op->emitError("the shape of output is not large enough to hold the "
4245  "packed data. Expected at least ")
4246  << expectedPackedType << ", got " << packedType;
4247  }
4248  if (!llvm::all_of(
4249  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4250  mixedTiles),
4251  [](std::tuple<int64_t, OpFoldResult> it) {
4252  int64_t shape = std::get<0>(it);
4253  if (Attribute attr =
4254  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4255  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4256  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4257  return shape == staticTileSize;
4258  }
4259  return ShapedType::isDynamic(shape);
4260  })) {
4261  return op->emitError("mismatch between the specified inner tile sizes and "
4262  "the shape of the tiled dimensions in the packed type");
4263  }
4264  return success();
4265 }
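// For example, a pack that tiles both dims of a tensor<128x256xf32> by
// inner_tiles = [8, 32] satisfies these checks: packed rank = 2 + 2, and
// tensor<16x8x8x32xf32> matches the inferred packed type exactly (a sketch;
// the shapes and SSA names are illustrative):
//
//   %pack = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//             into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>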
4266 
4267 namespace {
4268 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4269 /// various permutations to the op.
4270 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4271 // these. These may or may not become true foldings / canonicalizations
4272 // depending on how aggressive we want to be in automatically folding
4273 // transposes.
4274 struct PackOrUnPackTransposeResult {
4275  SmallVector<int64_t> innerDimsPos;
4276  SmallVector<OpFoldResult> innerTiles;
4277  SmallVector<int64_t> outerDimsPerm;
4278 };
4279 } // namespace
4280 
4281 template <typename OpTy>
4282 static PackOrUnPackTransposeResult
4283 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4284  ArrayRef<int64_t> innerPermutation,
4285  ArrayRef<int64_t> outerPermutation) {
4286  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4287  "applies to only pack or unpack operations");
4288  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4289  "some permutation must be non-empty");
4290  PackOrUnPackTransposeResult metadata;
4291  metadata.innerDimsPos =
4292  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4293  metadata.innerTiles =
4294  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4295  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4296  ? packOrUnPackOp.getSourceRank()
4297  : packOrUnPackOp.getDestRank();
4298  metadata.outerDimsPerm =
4299  packOrUnPackOp.getOuterDimsPerm().empty()
4300  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4301  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4302  if (!innerPermutation.empty()) {
4303  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4304  isPermutationVector(innerPermutation) &&
4305  "invalid inner permutation");
4306  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4307  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4308  }
4309  if (!outerPermutation.empty()) {
4310  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4311  isPermutationVector(outerPermutation) &&
4312  "invalid outer permutation");
4313  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4314  }
4315  return metadata;
4316 }
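// As a worked example (values chosen purely for illustration): with
// innerDimsPos = [0, 1], innerTiles = [8, 16], and innerPermutation = [1, 0],
// the returned metadata has innerDimsPos = [1, 0] and innerTiles = [16, 8];
// a non-empty outerPermutation permutes outerDimsPerm the same way.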
4317 
4318 //===----------------------------------------------------------------------===//
4319 // PackOp
4320 //===----------------------------------------------------------------------===//
4321 
4322 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4323  setNameFn(getResult(), "pack");
4324 }
4325 
4326 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4327  Value dest, ArrayRef<int64_t> innerDimsPos,
4328  ArrayRef<OpFoldResult> innerTiles,
4329  std::optional<Value> paddingValue,
4330  ArrayRef<int64_t> outerDimsPerm) {
4331  assert(innerDimsPos.size() == innerTiles.size() &&
4332  "number of tile sizes specified must match the specified number of "
4333  "original dimensions to be tiled");
4334  SmallVector<int64_t> staticTileSizes;
4335  SmallVector<Value> dynamicTileSizes;
4336  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4337  build(builder, state, dest.getType(), source, dest,
4338  paddingValue ? *paddingValue : nullptr,
4339  outerDimsPerm.empty() ? nullptr
4340  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4341  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4342  builder.getDenseI64ArrayAttr(staticTileSizes));
4343 }
4344 
4345 LogicalResult
4346 PackOp::reifyResultShapes(OpBuilder &builder,
4347  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4348  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4349 }
4350 
4351 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4352  return getDimAndTileMappingImpl(*this);
4353 }
4354 
4355 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4356  return getMixedTilesImpl(*this);
4357 }
4358 
4359 SmallVector<int64_t> PackOp::getStaticTiles() {
4360  return getStaticTilesImpl(*this);
4361 }
4362 
4363 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4364  ShapedType inputType = getSourceType();
4365  int64_t inputRank = inputType.getRank();
4366  return getDestType().getShape().take_front(inputRank);
4367 }
4368 
4369 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4370  auto innerDimsPos = getInnerDimsPos();
4371  auto packedShape = getDestType().getShape();
4372  SmallVector<int64_t> res;
4373 
4374  for (auto index : innerDimsPos)
4375  res.push_back(packedShape[index]);
4376 
4377  return res;
4378 }
4379 
4380 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4381  ArrayRef<int64_t> innerDimsPos,
4382  ArrayRef<int64_t> outputShape,
4383  ArrayRef<int64_t> outerDimsPerm,
4384  ArrayRef<OpFoldResult> innerTiles) {
4385  SmallVector<int64_t> outputTileSizes(
4386  outputShape.take_front(inputShape.size()));
4387  if (!outerDimsPerm.empty()) {
4388  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4389  "expected output and outer_dims_perm to have same size");
4390  applyPermutationToVector(outputTileSizes,
4391  invertPermutationVector(outerDimsPerm));
4392  }
4393  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4394  if (ShapedType::isDynamic(inputShape[pos]))
4395  continue;
4396  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4397 
4398  if (!constantTile) {
4399  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4400  (inputShape[pos] % outputTileSizes[pos] != 0))
4401  return true;
4402  } else if (inputShape[pos] % (*constantTile) != 0) {
4403  return true;
4404  }
4405  }
4406  return false;
4407 }
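// For example (illustrative shapes): with inputShape = [128, 130],
// innerDimsPos = [0, 1], and static innerTiles = [8, 16], dim 0 packs evenly
// (128 % 8 == 0) but dim 1 does not (130 % 16 != 0), so a padding value is
// required; with innerTiles = [8, 10] instead, no padding would be needed.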
4408 
4409 LogicalResult PackOp::verify() {
4410  if (failed(commonVerifierPackAndUnPackOp(*this)))
4411  return failure();
4412 
4413  // Verify padding value, and bail out if the tile does not divide the
4414  // dimension fully. In the case of dynamic tile factors or dimensions, having
4415  // a partial tile is undefined behavior.
4416  auto paddingValue = getPaddingValue();
4417  if (paddingValue &&
4418  paddingValue.getType() != getSourceType().getElementType()) {
4419  return emitOpError("expected padding_value has ")
4420  << getSourceType().getElementType()
4421  << " but got: " << paddingValue.getType();
4422  }
4423 
4424  if (!paddingValue &&
4425  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4426  getDestType().getShape(), getOuterDimsPerm(),
4427  getMixedTiles())) {
4428  return emitOpError(
4429  "invalid tile factor or output size provided. Only full tiles are "
4430  "supported when padding_value is not set");
4431  }
4432  return success();
4433 }
4434 
4435 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
4436 /// all Values to kDynamic, even if they are arith.constant values.
4437 static SmallVector<int64_t>
4438 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4439  SmallVector<int64_t> result;
4440  for (auto o : ofrs) {
4441  // Have to do this first, as getConstantIntValue special-cases constants.
4442  if (llvm::dyn_cast_if_present<Value>(o))
4443  result.push_back(ShapedType::kDynamic);
4444  else
4445  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4446  }
4447  return result;
4448 }
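// For example, the mixed sizes {IntegerAttr(8), %ssaValue, IntegerAttr(16)}
// map to the static shape {8, kDynamic, 16}; %ssaValue maps to kDynamic even
// when it is produced by an arith.constant.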
4449 
4450 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4451 /// the packed type. Sharing one helper ensures that the two methods agree on
4452 /// which dimensions are dynamic.
4453 static SmallVector<int64_t> getPackOpResultTypeShape(
4454  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4455  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4456  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4457  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4458  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4459  continue;
4460  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4461  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4462  continue;
4463  }
4464  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4465  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4466  }
4467 
4468  // Swap tile loops if outer_dims_perm is available.
4469  if (!outerDimsPerm.empty())
4470  applyPermutationToVector<int64_t>(resultShape, outerDimsPerm);
4471 
4472  // Append the inner tile dimensions.
4473  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4474  return resultShape;
4475 }
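// A worked example (illustrative values): sourceShape = [128, 256],
// innerTileSizes = [8, 32], innerDimsPos = [0, 1], outerDimsPerm = [1, 0]:
// tiling gives [16, 8], the outer permutation swaps this to [8, 16], and
// appending the inner tiles yields [8, 16, 8, 32].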
4476 
4477 SmallVector<OpFoldResult> PackOp::getResultShape(
4478  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4481  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4482 
4483  AffineExpr s0, s1;
4484  bindSymbols(builder.getContext(), s0, s1);
4485  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4486  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4487  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4488  builder, loc, ceilDivExpr,
4489  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4490  }
4491  if (!outerDimsPerm.empty())
4492  applyPermutationToVector<OpFoldResult>(resultDims, outerDimsPerm);
4493  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4494 
4495  SmallVector<int64_t> resultTypeShape =
4496  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
4497  asShapeWithAnyValueAsDynamic(innerTileSizes),
4498  innerDimsPos, outerDimsPerm);
4499 
4500  // Fix-up `resultDims` to ensure that they are Value's if and only if the
4501  // result type shape says it's a dynamic dim. This is needed as callers may
4502  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4503  // dynamic dims returned by that.
4504  for (unsigned i = 0; i < resultDims.size(); ++i) {
4505  if (!ShapedType::isDynamic(resultTypeShape[i]))
4506  continue;
4507  resultDims[i] =
4508  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4509  }
4510 
4511  return resultDims;
4512 }
4513 
4514 /// Get the expected packed type based on source type, tile factors, position of
4515 /// the inner tiles and permutation of the outer tiled loop.
4516 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4517  ArrayRef<int64_t> innerTileSizes,
4518  ArrayRef<int64_t> innerDimsPos,
4519  ArrayRef<int64_t> outerDimsPerm) {
4520  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4521  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4522  return RankedTensorType::get(resultShape, sourceType.getElementType());
4523 }
4524 
4525 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4526  ArrayRef<OpFoldResult> innerTileSizes,
4527  ArrayRef<int64_t> innerDimsPos,
4528  ArrayRef<int64_t> outerDimsPerm) {
4529  AffineExpr dim0, dim1;
4530  bindDims(b.getContext(), dim0, dim1);
4531  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4532  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4533  {v1, v2});
4534  };
4535 
4536  SmallVector<OpFoldResult> mixedSizes;
4537  for (auto [index, value] : llvm::enumerate(
4538  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4539  if (ShapedType::isDynamic(value))
4540  mixedSizes.push_back(
4541  b.create<tensor::DimOp>(loc, source, index).getResult());
4542  else
4543  mixedSizes.push_back(b.getIndexAttr(value));
4544  }
4545  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4546  int64_t dimPos = std::get<0>(it);
4547  OpFoldResult tileSize = std::get<1>(it);
4548  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4549  }
4550  if (!outerDimsPerm.empty())
4551  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4552 
4553  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4554  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4555  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4556 }
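// For instance (illustrative shapes), for a source of type tensor<?x256xf32>
// with innerTileSizes = [8, 16] on innerDimsPos = [0, 1] and no outer
// permutation, this creates a tensor.empty of type tensor<?x16x8x16xf32>,
// where the leading dynamic size is the ceil-division of dim 0 of the source
// by 8.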
4557 
4558 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4559  ArrayRef<int64_t> innerPermutation,
4560  ArrayRef<int64_t> outerPermutation) {
4561  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4562  *this, innerPermutation, outerPermutation);
4563  Value transposedDest =
4564  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4565  metadata.innerDimsPos, metadata.outerDimsPerm);
4566  return b.create<PackOp>(loc, getSource(), transposedDest,
4567  metadata.innerDimsPos, metadata.innerTiles,
4568  getPaddingValue(), metadata.outerDimsPerm);
4569 }
4570 
4571 /// Returns true if the tiles and the tiled dims are constant.
4572 template <typename OpTy>
4573 bool areTilesAndTiledDimsAllConstant(OpTy op) {
4574  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4575  "applies to only pack or unpack operations");
4576  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4577  ? op.getDestType()
4578  : op.getSourceType();
4579  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4580  for (auto [dimDest, tile] : llvm::zip(
4581  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4582  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4583  if (!constTileSize || ShapedType::isDynamic(dimDest))
4584  return false;
4585  }
4586  return true;
4587 }
4588 
4589 Speculation::Speculatability PackOp::getSpeculatability() {
4590  if (getPaddingValue())
4591  return Speculation::Speculatable;
4592 
4593  // The verifier already rejects operations if we can statically prove that
4594  // the tile sizes do not evenly divide the dimension; thus, it suffices to
4595  // check that the tiles and the tiled inner dimensions are constant.
4596  if (!areTilesAndTiledDimsAllConstant(*this))
4597  return Speculation::NotSpeculatable;
4598 
4599  return Speculation::Speculatable;
4600 }
4601 
4602 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4603 // dimensions for pack and unpack.
4604 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4605  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4606  return false;
4607  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4608  return true;
4609  // Outer dims permutation is optional.
4610  // To compare an unbalanced pack-unpack pair, treat a missing permutation as
4611  // the identity permutation.
4612  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4613  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4614 }
4615 
4616 // Return true if pack and unpack have the same tiles.
4617 // Same SSA values or same integer constants.
4618 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4619  auto packTiles = packOp.getMixedTiles();
4620  auto unPackTiles = unPackOp.getMixedTiles();
4621  if (packTiles.size() != unPackTiles.size())
4622  return false;
4623  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4624  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4625  return false;
4626  }
4627  return true;
4628 }
4629 
4630 /// Returns true if the pack op does not need a padding value.
4631 static bool paddingIsNotNeeded(PackOp op) {
4632  auto srcType = op.getSourceType();
4633  if (llvm::any_of(op.getInnerDimsPos(),
4634  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4635  return false;
4636  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4637  return false;
4638  return !PackOp::requirePaddingValue(
4639  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4640  op.getOuterDimsPerm(), op.getMixedTiles());
4641 }
4642 
4643 /// Returns true if the `srcShape` or `destShape` is different from the one in
4644 /// `packOp` and populates each with the inferred static shape.
4645 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4646  SmallVectorImpl<int64_t> &destShape) {
4647  bool changeNeeded = false;
4648  srcShape.assign(packOp.getSourceType().getShape().begin(),
4649  packOp.getSourceType().getShape().end());
4650  destShape.assign(packOp.getDestType().getShape().begin(),
4651  packOp.getDestType().getShape().end());
4652  llvm::SmallSetVector<int64_t, 4> innerDims;
4653  innerDims.insert(packOp.getInnerDimsPos().begin(),
4654  packOp.getInnerDimsPos().end());
4655  SmallVector<int64_t> inverseOuterDimsPerm;
4656  if (!packOp.getOuterDimsPerm().empty())
4657  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4658  int srcRank = packOp.getSourceRank();
4659  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4660  if (innerDims.contains(i))
4661  continue;
4662  int64_t srcPos = i;
4663  int64_t destPos = i;
4664  if (!inverseOuterDimsPerm.empty())
4665  destPos = inverseOuterDimsPerm[srcPos];
4666  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4667  ShapedType::isDynamic(destShape[destPos])) {
4668  continue;
4669  }
4670  int64_t size = srcShape[srcPos];
4671  if (ShapedType::isDynamic(size))
4672  size = destShape[destPos];
4673  srcShape[srcPos] = size;
4674  destShape[destPos] = size;
4675  changeNeeded = true;
4676  }
4677  return changeNeeded;
4678 }
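// For example (illustrative types): with inner_dims_pos = [1], a source of
// type tensor<?x256xf32>, and a dest of type tensor<16x16x16xf32>, dim 0 is
// untiled, so its static size 16 is propagated from destShape into srcShape
// and the function returns true.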
4679 
4680 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4681  // Fold pack(unpack(x)) to x.
4682  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4683  if (unPackOp.getSourceType() != packOp.getDestType())
4684  return failure();
4685  if (packOp.getPaddingValue() ||
4686  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4687  !haveSameTiles(packOp, unPackOp))
4688  return failure();
4689  rewriter.replaceOp(packOp, unPackOp.getSource());
4690  return success();
4691  }
4692 
4693  // Fold optional PaddingValue operand away if padding is not needed.
4694  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4695  rewriter.startOpModification(packOp);
4696  packOp.getPaddingValueMutable().clear();
4697  rewriter.finalizeOpModification(packOp);
4698  return success();
4699  }
4700 
4701  // Insert tensor.cast ops if static shape inference is available.
4702  SmallVector<int64_t> srcShape, destShape;
4703  if (inferStaticShape(packOp, srcShape, destShape)) {
4704  Location loc = packOp.getLoc();
4705  Value source = packOp.getSource();
4706  if (srcShape != packOp.getSourceType().getShape()) {
4707  auto newSrcType = packOp.getSourceType().clone(srcShape);
4708  source =
4709  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4710  }
4711  Value dest = packOp.getDest();
4712  RankedTensorType originalResultType = packOp.getDestType();
4713  bool needUpdateDestType = (destShape != originalResultType.getShape());
4714  if (needUpdateDestType) {
4715  auto newDestType = packOp.getDestType().clone(destShape);
4716  dest =
4717  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4718  }
4719  rewriter.modifyOpInPlace(packOp, [&] {
4720  packOp.getSourceMutable().assign(source);
4721  packOp.getDestMutable().assign(dest);
4722  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4723  });
4724  // Insert a cast if needed
4725  if (needUpdateDestType) {
4726  rewriter.setInsertionPointAfter(packOp);
4727  auto castOp =
4728  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4729  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4730  }
4731  return success();
4732  }
4733 
4734  return failure();
4735 }
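// As an example of the pack(unpack(x)) fold above (illustrative IR):
//
// ```mlir
// %0 = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//     into %e0 : tensor<16x16x8x16xf32> -> tensor<128x256xf32>
// %1 = linalg.pack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//     into %e1 : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
// ```
//
// All uses of %1 are replaced by %x, since the inner/outer attributes and the
// tiles match and no padding value is present.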
4736 
4737 template <typename PackOrUnpackOp>
4738 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4739  RankedTensorType packedTensorType) {
4740  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
4741  std::is_same<PackOrUnpackOp, UnPackOp>::value,
4742  "Function meant for pack/unpack");
4743  // This is a pad if packing only adds ones and we don't transpose dimensions.
4744 
4745  // Check that we are not transposing any dimensions.
4746  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4747  int64_t numPackedDims = innerDimsPos.size();
4748  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4749  if (orderedDims != innerDimsPos) {
4750  // Dimensions don't happen in order.
4751  return false;
4752  }
4753 
4754  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4755  int64_t packedRank = packedTensorType.getRank();
4756  // At this point we know that we are taking numPackedDims outer
4757  // dimensions and pushing them all the way as the innermost dimensions.
4758  // What's left on the outermost dimensions is, in this order:
4759  // - the factors of the packed dimensions, then
4760  // - the untouched dimensions
4761  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4762  // if all the dimensions that bubble outward are ones.
4763  // Therefore, check that all the dimensions except the numPackedDims
4764  // innermost ones are ones.
4765  return llvm::all_of(
4766  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4767  [&packedShape](int64_t i) { return packedShape[i] == 1; });
4768 }
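// For example (illustrative shapes), a pack of tensor<5x10xf32> with
// inner_dims_pos = [0, 1], inner_tiles = [8, 16], and a padding value
// produces tensor<1x1x8x16xf32>: all dims except the two innermost tile dims
// are ones, so the op behaves like a tensor.pad.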
4769 
4770 bool PackOp::isLikePad() {
4771  auto packedTensorType =
4772  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4773  return isLikePadUnPad(*this, packedTensorType);
4774 }
4775 
4776 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
4777  std::optional<Attribute> paddingValue;
4778  if (auto pad = adaptor.getPaddingValue())
4779  paddingValue = pad;
4780  if (OpFoldResult reshapedSource = reshapeConstantSource(
4781  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4782  getDestType(), paddingValue))
4783  return reshapedSource;
4784  return {};
4785 }
4786 
4787 /// Folds a tensor.cast op into a consuming PackOp if the
4788 /// `tensor.cast` has a source that is more static than the consuming op.
4789 ///
4790 /// Example:
4791 /// ```mlir
4792 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
4793 /// %2 = linalg.pack %1 ... : tensor<?x?xf32> ...
4794 /// ```
4795 ///
4796 /// folds into:
4797 ///
4798 /// ```mlir
4799 /// %2 = linalg.pack %0 ... : tensor<8x16xf32> ...
4800 /// ```
4801 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
4802  using OpRewritePattern<PackOp>::OpRewritePattern;
4803 
4804  LogicalResult matchAndRewrite(PackOp op,
4805  PatternRewriter &rewriter) const override {
4807  return failure();
4808 
4809  SmallVector<Type> newResultTypes(op->getResultTypes());
4810  SmallVector<Value> newOperands =
4811  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
4812 
4813  // Get the updated mixed-tile-sizes attribute.
4814  SmallVector<OpFoldResult> newMixedTileSizes =
4815  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
4816 
4817  // Clone op.
4818  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
4819  // this point. However, in practice, we use them for things that we'd like
4820  // to preserve. Implement a better abstraction.
4821  PackOp newOp = rewriter.create<PackOp>(
4822  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
4823  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
4824  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
4825 
4826  // Replace op.
4827  Value oldResult = op.getResult();
4828  Value newResult = newOp.getResult();
4829  Value replacement = (newResult.getType() != oldResult.getType())
4830  ? rewriter.create<tensor::CastOp>(
4831  op->getLoc(), oldResult.getType(), newResult)
4832  : newResult;
4833 
4834  rewriter.replaceOp(op, {replacement});
4835 
4836  return success();
4837  }
4838 };
4839 
4840 //===----------------------------------------------------------------------===//
4841 // UnPackOp
4842 //===----------------------------------------------------------------------===//
4843 
4844 void UnPackOp::getAsmResultNames(
4845  function_ref<void(Value, StringRef)> setNameFn) {
4846  setNameFn(getResult(), "unpack");
4847 }
4848 
4849 LogicalResult
4850 UnPackOp::reifyResultShapes(OpBuilder &builder,
4851  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4852  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4853 }
4854 
4855 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
4856  return getDimAndTileMappingImpl(*this);
4857 }
4858 
4859 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
4860  return getMixedTilesImpl(*this);
4861 }
4862 
4863 SmallVector<int64_t> UnPackOp::getStaticTiles() {
4864  return getStaticTilesImpl(*this);
4865 }
4866 
4867 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
4868  ShapedType destType = getDestType();
4869  int64_t destRank = destType.getRank();
4870  return getSourceType().getShape().take_front(destRank);
4871 }
4872 
4873 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
4874  auto innerDimsPos = getInnerDimsPos();
4875  auto packedShape = getSourceType().getShape();
4876  SmallVector<int64_t> res;
4877 
4878  for (auto index : innerDimsPos)
4879  res.push_back(packedShape[index]);
4880 
4881  return res;
4882 }
4883 
4884 LogicalResult UnPackOp::verify() {
4885  return commonVerifierPackAndUnPackOp(*this);
4886 }
4887 
4888 Speculation::Speculatability UnPackOp::getSpeculatability() {
4889  // See PackOp::getSpeculatability.
4890  if (!areTilesAndTiledDimsAllConstant(*this))
4891  return Speculation::NotSpeculatable;
4892 
4893  return Speculation::Speculatable;
4894 }
4895 
4896 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
4897  Value dest, ArrayRef<int64_t> innerDimsPos,
4898  ArrayRef<OpFoldResult> innerTiles,
4899  ArrayRef<int64_t> outerDimsPerm) {
4900  assert(innerDimsPos.size() == innerTiles.size() &&
4901  "number of tile sizes specified must match the specified number of "
4902  "original dimensions to be tiled");
4903  SmallVector<int64_t> staticTileSizes;
4904  SmallVector<Value> dynamicTileSizes;
4905  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4906  build(builder, state, dest.getType(), source, dest,
4907  outerDimsPerm.empty() ? nullptr
4908  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4909  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4910  builder.getDenseI64ArrayAttr(staticTileSizes));
4911 }
4912 
4913 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
4914  Value source,
4915  ArrayRef<OpFoldResult> innerTileSizes,
4916  ArrayRef<int64_t> innerDimsPos,
4917  ArrayRef<int64_t> outerDimsPerm) {
4918  AffineExpr sym0, sym1;
4919  bindSymbols(b.getContext(), sym0, sym1);
4920  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4921  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
4922  };
4923 
4924  SmallVector<OpFoldResult> mixedSizes;
4925  auto srcType = llvm::cast<RankedTensorType>(source.getType());
4926  for (auto i :
4927  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4928  if (srcType.isDynamicDim(i))
4929  mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
4930  else
4931  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
4932  }
4933  if (!outerDimsPerm.empty()) {
4934  applyPermutationToVector<OpFoldResult>(
4935  mixedSizes, invertPermutationVector(outerDimsPerm));
4936  }
4937 
4938  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4939  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4940 
4941  auto elemType = srcType.getElementType();
4942  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4943 }
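// For instance (illustrative shapes), for a source of type
// tensor<16x16x8x16xf32> with innerTileSizes = [8, 16] on
// innerDimsPos = [0, 1], the computed sizes are [16 * 8, 16 * 16] and this
// creates a tensor.empty of type tensor<128x256xf32>.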
4944 
4945 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
4946  Value transposedSource,
4947  ArrayRef<int64_t> innerPermutation,
4948  ArrayRef<int64_t> outerPermutation) {
4949  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4950  *this, innerPermutation, outerPermutation);
4951  return b.create<UnPackOp>(loc, transposedSource, getDest(),
4952  metadata.innerDimsPos, metadata.innerTiles,
4953  metadata.outerDimsPerm);
4954 }
4955 
4956 /// Returns true if the `srcShape` or `destShape` is different from the one in
4957 /// `op` and populates each with the inferred static shape.
4958 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
4959  SmallVectorImpl<int64_t> &destShape) {
4960  bool changeNeeded = false;
4961  srcShape.assign(op.getSourceType().getShape().begin(),
4962  op.getSourceType().getShape().end());
4963  destShape.assign(op.getDestType().getShape().begin(),
4964  op.getDestType().getShape().end());
4965  llvm::SmallSetVector<int64_t, 4> innerDims;
4966  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4967  SmallVector<int64_t> inverseOuterDimsPerm;
4968  if (!op.getOuterDimsPerm().empty())
4969  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
4970  int destRank = op.getDestRank();
4971  for (auto i : llvm::seq<int64_t>(0, destRank)) {
4972  if (innerDims.contains(i))
4973  continue;
4974  int64_t srcPos = i;
4975  int64_t destPos = i;
4976  if (!inverseOuterDimsPerm.empty())
4977  srcPos = inverseOuterDimsPerm[destPos];
4978  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4979  ShapedType::isDynamic(destShape[destPos])) {
4980  continue;
4981  }
4982  int64_t size = srcShape[srcPos];
4983  if (ShapedType::isDynamic(size))
4984  size = destShape[destPos];
4985  srcShape[srcPos] = size;
4986  destShape[destPos] = size;
4987  changeNeeded = true;
4988  }
4989  return changeNeeded;
4990 }
4991 
4992 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4993  PatternRewriter &rewriter) {
4994  /// unpack(pack(x)) -> x
4995  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
4996  if (packOp.getSourceType() != unPackOp.getDestType())
4997  return failure();
4998  if (packOp.getPaddingValue() ||
4999  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5000  !haveSameTiles(packOp, unPackOp))
5001  return failure();
5002  rewriter.replaceOp(unPackOp, packOp.getSource());
5003  return success();
5004  }
5005  /// unpack(destinationStyleOp(x)) -> unpack(x)
5006  if (auto dstStyleOp =
5007  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5008  auto destValue = cast<OpResult>(unPackOp.getDest());
5009  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5010  rewriter.modifyOpInPlace(unPackOp,
5011  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5012  return success();
5013  }
5014 
5015  // Insert tensor.cast ops if static shape inference is available.
5016  SmallVector<int64_t> srcShape, destShape;
5017  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5018  Location loc = unPackOp.getLoc();
5019  Value source = unPackOp.getSource();
5020  if (srcShape != unPackOp.getSourceType().getShape()) {
5021  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5022  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5023  unPackOp.getSource());
5024  }
5025  Value dest = unPackOp.getDest();
5026  if (destShape != unPackOp.getDestType().getShape()) {
5027  auto newDestType = unPackOp.getDestType().clone(destShape);
5028  dest =
5029  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5030  }
5031  Value newOp = rewriter.create<UnPackOp>(
5032  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5033  unPackOp.getOuterDimsPerm());
5034  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5035  unPackOp, unPackOp.getResult().getType(), newOp);
5036  return success();
5037  }
5038 
5039  return failure();
5040 }
5041 
5042 bool UnPackOp::isLikeUnPad() {
5043  RankedTensorType packedTensorType = getSourceType();
5044  return isLikePadUnPad(*this, packedTensorType);
5045 }
5046 
5047 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5048  if (OpFoldResult reshapedSource = reshapeConstantSource(
5049  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5050  getResult().getType()))
5051  return reshapedSource;
5052  return {};
5053 }
5054 
5055 /// Folds a tensor.cast op into a consuming UnPackOp if the
5056 /// `tensor.cast` has a source that is more static than the consuming op.
5057 ///
5058 /// Example:
5059 /// ```mlir
5060 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5061 /// %2 = linalg.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5062 /// ```
5063 ///
5064 /// folds into:
5065 ///
5066 /// ```mlir
5067 /// %2 = linalg.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5068 /// ```
5069 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5070  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5071 
5072  LogicalResult matchAndRewrite(UnPackOp op,
5073  PatternRewriter &rewriter) const override {
5075  return failure();
5076 
5077  SmallVector<Type> newResultTypes(op->getResultTypes());
5078  SmallVector<Value> newOperands =
5079  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5080  Value sourceTensor = newOperands[0];
5081 
5082  // Get the updated mixed-tile-sizes attribute.
5083  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5084  rewriter, sourceTensor.getType(), op.getMixedTiles());
5085 
5086  // Clone op.
5087  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5088  // this point. However, in practice, we use them for things that we'd like
5089  // to preserve. Implement a better abstraction.
5090  UnPackOp newOp = rewriter.create<UnPackOp>(
5091  op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5092  newMixedTileSizes, op.getOuterDimsPerm());
5093  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5094 
5095  // Replace op.
5096  Value oldResult = op.getResult();
5097  Value newResult = newOp.getResult();
5098  Value replacement = (newResult.getType() != oldResult.getType())
5099  ? rewriter.create<tensor::CastOp>(
5100  op->getLoc(), oldResult.getType(), newResult)
5101  : newResult;
5102 
5103  rewriter.replaceOp(op, {replacement});
5104 
5105  return success();
5106  }
5107 };
5108 
5109 } // namespace linalg
5110 } // namespace mlir
5111 
5112 //===----------------------------------------------------------------------===//
5113 // LinalgDialect
5114 //===----------------------------------------------------------------------===//
5115 
5116 void LinalgDialect::getCanonicalizationPatterns(
5117  RewritePatternSet &results) const {
5118  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5119  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5120 }
5121 
5122 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5123  Attribute value, Type type,
5124  Location loc) {
5125  return arith::ConstantOp::materialize(builder, value, type, loc);
5126 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3444
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1852
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:310
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2785
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2856
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
Definition: LinalgOps.cpp:3422
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:4277
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:126
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2324
SmallVector< OpFoldResult > innerTiles
Definition: LinalgOps.cpp:4276
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
Definition: LinalgOps.cpp:3437
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:298
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
Definition: LinalgOps.cpp:3501
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1685
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1733
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2830
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:188
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3524
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1485
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:160
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1354
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:204
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2807
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1245
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1212
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2235
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4275
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:336
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:987
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:329
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:59
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1408
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1514
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:366
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
Definition: LinalgOps.cpp:3468
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:373
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:224
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:964
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:222
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:114
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:96
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1148
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1198
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2630
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4097
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: LinalgOps.cpp:4645
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:4738
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: LinalgOps.cpp:4438
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4153
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: LinalgOps.cpp:4170
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4140
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4453
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2316
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4283
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:106
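
Usage note: a sketch (hypothetical helper; assumes the declaration in mlir/Dialect/Linalg/IR/Linalg.h) contrasting this with createOrFoldDimOp below: the folded variant returns an attribute for static dimensions and creates a dim op only for dynamic ones.

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical helper: size of dimension 0 of `val` as an OpFoldResult.
// Static dim -> IndexAttr; dynamic dim -> a new memref.dim / tensor.dim.
mlir::OpFoldResult dim0(mlir::OpBuilder &b, mlir::Location loc,
                        mlir::Value val) {
  return mlir::linalg::createFoldedDimOp(b, loc, val, /*dim=*/0);
}
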
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level.
Definition: LinalgOps.cpp:2357
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:4631
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2296
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx to startIdx + num.
Definition: LinalgOps.cpp:2307
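
Usage note: a sketch (hypothetical function; assumes these helpers are declared in mlir/Dialect/Linalg/IR/Linalg.h) of how the named-op builders combine makeAffineDimExprs with concat, documented above:

#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical example: allocate d0, d1 and then d2, and concatenate.
void buildDimExprs(mlir::MLIRContext *ctx) {
  unsigned startIdx = 0;
  auto ab = mlir::linalg::makeAffineDimExprs(2, startIdx, ctx); // d0, d1
  auto c = mlir::linalg::makeAffineDimExprs(1, startIdx, ctx);  // d2
  auto all = mlir::linalg::concat(ab, c); // d0, d1, d2
  (void)all;
}
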
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4109
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4065
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4618
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:97
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4604
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:4573
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4124
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4185
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3626
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
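
Usage note: a sketch (assumes mlir/Analysis/Presburger/Fraction.h; exact constructor and operator availability may differ) of the rounding helpers on the Presburger library's exact rational type:

#include "mlir/Analysis/Presburger/Fraction.h"

// Hypothetical example: rounding the exact rational 7/2 == 3.5.
void fractionDemo() {
  using namespace mlir::presburger;
  Fraction f(7, 2);
  DynamicAPInt down = floor(f);        // 3
  DynamicAPInt up = ceil(f);           // 4
  DynamicAPInt nearest = round(f);     // 4, rounds to the nearest integer
  Fraction mag = abs(Fraction(-7, 2)); // 7/2
  (void)down; (void)up; (void)nearest; (void)mag;
}
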
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer, i.e. a cast whose source type is more static than its result type.
Definition: TensorOps.cpp:358
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:351
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e. a cast that can be folded into op), return the updated operands and populate newResTy with the new result types.
Definition: TensorOps.cpp:367
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
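
Usage note: a sketch (hypothetical wrapper) of collecting every dimension of a tensor value as OpFoldResults:

#include "mlir/Dialect/Tensor/IR/Tensor.h"

// Hypothetical wrapper: static dims come back as IntegerAttrs, dynamic
// dims as results of freshly created tensor.dim ops.
mlir::SmallVector<mlir::OpFoldResult>
allSizes(mlir::OpBuilder &b, mlir::Location loc, mlir::Value tensor) {
  return mlir::tensor::getMixedSizes(b, loc, tensor);
}
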
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:348
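
Usage note: a sketch (hypothetical function) binding dim exprs and building an affine map from them:

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Hypothetical example: build (d0, d1) -> (d0 + d1) via bindDims.
mlir::AffineMap buildSumMap(mlir::MLIRContext *ctx) {
  mlir::AffineExpr d0, d1;
  mlir::bindDims(ctx, d0, d1);
  return mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
}
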
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:791
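
Usage note: a sketch (hypothetical function) inverting a permutation map:

#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical example: (d0, d1, d2) -> (d2, d0, d1) inverts to
// (d0, d1, d2) -> (d1, d2, d0).
mlir::AffineMap invertExample(mlir::MLIRContext *ctx) {
  llvm::SmallVector<unsigned> permVec = {2, 0, 1};
  auto perm = mlir::AffineMap::getPermutationMap(permVec, ctx);
  return mlir::inversePermutation(perm);
}
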
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
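
Usage note: a sketch (hypothetical function; the asserts are illustrative) showing that shaped types are unwrapped while scalars pass through unchanged:

#include <cassert>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"

// Hypothetical example: element type of a shaped type, identity on scalars.
void elementTypeDemo(mlir::MLIRContext *ctx) {
  auto f32 = mlir::Float32Type::get(ctx);
  auto vec = mlir::VectorType::get({4}, f32);
  assert(mlir::getElementTypeOrSelf(vec) == f32);
  assert(mlir::getElementTypeOrSelf(f32) == f32);
}
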
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:362
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
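
Usage note: a sketch (hypothetical helper; assumes mlir/Dialect/Utils/StaticValueUtils.h and mlir/Dialect/Arith/Utils/Utils.h) of the common OpFoldResult round trip:

#include <optional>
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

// Hypothetical helper: branch on whether `ofr` is statically known, and
// materialize a Value (an arith.constant for the attribute case) only
// when one is required.
mlir::Value forceValue(mlir::OpBuilder &b, mlir::Location loc,
                       mlir::OpFoldResult ofr) {
  if (std::optional<int64_t> cst = mlir::getConstantIntValue(ofr)) {
    // Statically known; callers can often avoid creating IR entirely.
    (void)*cst;
  }
  return mlir::getValueOrCreateConstantIndexOp(b, loc, ofr);
}
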
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
Definition: Utils.cpp:1297
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the detail classes directly.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:425
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
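
Usage note: a sketch (hypothetical function; assumes mlir/Dialect/Utils/IndexingUtils.h) tying the permutation-vector helpers together:

#include <cassert>
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical example: validate, invert, and apply a permutation.
void permutationDemo() {
  llvm::SmallVector<int64_t> perm = {2, 0, 1};
  assert(mlir::isPermutationVector(perm));
  llvm::SmallVector<int64_t> inv = mlir::invertPermutationVector(perm);
  // inv == {1, 2, 0}
  llvm::SmallVector<int64_t> vec = {10, 20, 30};
  mlir::applyPermutationToVector(vec, perm); // vec == {30, 10, 20}
  (void)inv;
}
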
Fold transpose with transpose.
Definition: LinalgOps.cpp:1995
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1998
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
Definition: LinalgOps.cpp:2020
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2023
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static than the consuming op.
Definition: LinalgOps.cpp:4801
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:4804
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static than the consuming op.
Definition: LinalgOps.cpp:5069
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5072