TosaOps.cpp
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/TypeSwitch.h"
31
32#include <numeric>
33
34using namespace mlir;
35using namespace mlir::tosa;
36
37#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
39
40//===----------------------------------------------------------------------===//
41// Tosa dialect interface includes.
42//===----------------------------------------------------------------------===//
43
44#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
45#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
48
49namespace {
50#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
51
52//===----------------------------------------------------------------------===//
53// Dialect Function Inliner Interface.
54//===----------------------------------------------------------------------===//
55struct TosaInlinerInterface : public DialectInlinerInterface {
56 using DialectInlinerInterface::DialectInlinerInterface;
57
58 //===--------------------------------------------------------------------===//
59 // Analysis Hooks.
60 //===--------------------------------------------------------------------===//
61
62 /// All operations can be inlined by default.
63 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
64 IRMapping &map) const final {
65 return true;
66 }
67
68 /// All regions with If and While parent operators can be inlined.
69 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
70 IRMapping &map) const final {
71 return (isa<tosa::IfOp>(dest->getParentOp()) ||
72 isa<tosa::WhileOp>(dest->getParentOp()));
73 }
74};
75
76/// This class implements the bytecode interface for the Tosa dialect.
77struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
78 TosaDialectBytecodeInterface(Dialect *dialect)
79 : BytecodeDialectInterface(dialect) {}
80
81 //===--------------------------------------------------------------------===//
82 // Attributes
83
84 Attribute readAttribute(DialectBytecodeReader &reader) const override {
85 return ::readAttribute(getContext(), reader);
86 }
87
88 LogicalResult writeAttribute(Attribute attr,
89 DialectBytecodeWriter &writer) const override {
90 return ::writeAttribute(attr, writer);
91 }
92
93 //===--------------------------------------------------------------------===//
94 // Types
95
96 Type readType(DialectBytecodeReader &reader) const override {
97 return ::readType(getContext(), reader);
98 }
99
100 LogicalResult writeType(Type type,
101 DialectBytecodeWriter &writer) const override {
102 return ::writeType(type, writer);
103 }
104
105 void writeVersion(DialectBytecodeWriter &writer) const final {
106 // TODO: Populate.
107 }
108
109 std::unique_ptr<DialectVersion>
110 readVersion(DialectBytecodeReader &reader) const final {
111 // TODO: Populate
112 reader.emitError("Dialect does not support versioning");
113 return nullptr;
114 }
115
116 LogicalResult upgradeFromVersion(Operation *topLevelOp,
117 const DialectVersion &version) const final {
118 return success();
119 }
120};
121
122} // namespace
123
124//===----------------------------------------------------------------------===//
125// TOSA control flow support.
126//===----------------------------------------------------------------------===//
127
128/// Returns the while loop body.
129SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
130 return {&getBodyGraph()};
131}
132
133//===----------------------------------------------------------------------===//
134// TOSA variable operator support.
135//===----------------------------------------------------------------------===//
136
137SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
138 return to_vector(llvm::map_range(shape, [](int64_t dim) {
139 return dim == -1 ? ShapedType::kDynamic : dim;
140 }));
141}
142
143// Returns the tensor type of a tosa.variable op, reconstructed from its
// var_shape and type attributes.
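// For example (illustration only, not taken from this file): a declaration such
// as `tosa.variable @v = dense<0.0> : tensor<2x4xf32>` stores var_shape = [2, 4]
// and type = f32, so this helper reconstructs tensor<2x4xf32>.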
144RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
145 Type elementType = variableOp.getType();
146 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
147 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
148 return RankedTensorType::get(shape, elementType);
149}
150
151//===----------------------------------------------------------------------===//
152// Tosa dialect initialization.
153//===----------------------------------------------------------------------===//
154
155void TosaDialect::initialize() {
156 addTypes<
157#define GET_TYPEDEF_LIST
158#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
159 >();
160 addOperations<
161#define GET_OP_LIST
162#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
163 >();
164 addAttributes<
165#define GET_ATTRDEF_LIST
166#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
167 >();
168 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
169 declarePromisedInterfaces<
170 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
171 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
172 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
173 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
174 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
175 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
176 GreaterEqualOp, MatMulOp>();
177}
178
179Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
180 Type type, Location loc) {
181 // Tosa dialect constants only support ElementsAttr unlike standard dialect
182 // constant which supports all attributes.
183 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
184 return tosa::ConstShapeOp::create(builder, loc, type,
185 llvm::cast<DenseIntElementsAttr>(value));
186 }
187 if (llvm::isa<ElementsAttr>(value))
188 return tosa::ConstOp::create(builder, loc, type,
189 llvm::cast<ElementsAttr>(value));
190 return nullptr;
191}
192
193//===----------------------------------------------------------------------===//
194// Parsers and printers
195//===----------------------------------------------------------------------===//
196
197namespace {
198
199ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
200 DenseElementsAttr &varShapeAttr,
201 TypeAttr &typeAttr) {
202 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
203 if (!shapedType.hasRank())
204 return parser.emitError(parser.getCurrentLocation())
205 << "expected ranked type";
206
207 auto elementType = shapedType.getElementType();
208 typeAttr = TypeAttr::get(elementType);
209 ArrayRef<int64_t> shape = shapedType.getShape();
210 Builder builder(parser.getContext());
211 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
212 return success();
213 }
214 return parser.emitError(parser.getCurrentLocation())
215 << "expected shaped type";
216}
217
218} // namespace
219
220// parses the optional initial value or type for a tosa variable
221// with initial value:
222// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
223//
224// without initial value:
225// tosa.variable @name : tensor<1x8xf32>
226ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
227 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
228 Attribute &initialValueAttr) {
229 if (succeeded(parser.parseOptionalEqual())) {
230 if (failed(parser.parseAttribute(initialValueAttr))) {
231 return parser.emitError(parser.getCurrentLocation())
232 << "expected attribute";
233 }
234 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
235 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
236 typeAttr);
237 }
238 return parser.emitError(parser.getCurrentLocation())
239 << "expected Typed attr";
240 }
241
242 initialValueAttr = nullptr;
243 Type parsedType;
244 if (failed(parser.parseColonType(parsedType))) {
245 return parser.emitError(parser.getCurrentLocation())
246 << "expected type after colon";
247 }
248 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
249}
250
251void mlir::tosa::printVariableOpTypeOrInitialValue(
252 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
253 TypeAttr typeAttr, Attribute initialValueAttr) {
254 bool needsSpace = false;
255 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
256 auto shape =
257 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
258 Type elementType = typeAttr.getValue();
259 RankedTensorType tensorType =
260 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
261 auto tensorTypeAttr = TypeAttr::get(tensorType);
262 p << ": ";
263 p.printAttribute(tensorTypeAttr);
264 needsSpace = true; // subsequent attr value needs a space separator
265 }
266 if (initialValueAttr) {
267 if (needsSpace)
268 p << ' ';
269 p << "= ";
270 p.printAttribute(initialValueAttr);
271 }
272}
273
274namespace {
275
276// parse attributes with special handling for tosa enum attributes
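// For instance (a sketch of the accepted textual forms, assuming the default
// enum attribute syntax): the bare-keyword form
//   rounding_mode = SINGLE_ROUND
// is accepted in addition to the fully qualified attribute form
//   rounding_mode = #tosa<rounding_mode SINGLE_ROUND>.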
277template <typename EnumType>
278ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
279 NamedAttrList &outAttrs) {
280 llvm::StringRef name;
281 if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
282 return failure();
283
284 // special handling: rounding_mode accepts a *bare* RoundingMode enum
285 // keyword.
286 llvm::StringRef kw;
287 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
288 if (name == "rounding_mode" &&
289 succeeded(parser.parseOptionalKeyword(&kw))) {
290 auto sym = symbolizeRoundingMode(kw);
291 if (!sym)
292 return parser.emitError(parser.getCurrentLocation())
293 << "invalid rounding_mode value: " << kw;
294 auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
295 outAttrs.push_back(NamedAttribute(name, attr));
296 return success();
297 }
298 }
299 // special handling: mode accepts a *bare* ResizeMode enum keyword.
300 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
301 if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
302 auto sym = symbolizeResizeMode(kw);
303 if (!sym)
304 return parser.emitError(parser.getCurrentLocation())
305 << "invalid resize mode value: " << kw;
306 auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
307 outAttrs.push_back(NamedAttribute(name, attr));
308 return success();
309 }
310 }
311 // special handling: nan_mode accepts a *bare* NanPropagationMode enum
312 // keyword.
313 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
314 if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
315 auto sym = symbolizeNanPropagationMode(kw);
316 if (!sym)
317 return parser.emitError(parser.getCurrentLocation())
318 << "invalid nan_mode value: " << kw;
319 auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
320 outAttrs.push_back(NamedAttribute(name, attr));
321 return success();
322 }
323 }
324
325 // special handling: block_size accepts a *bare* BlockSizeMode enum
326 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
327 if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
328 auto sym = symbolizeBlockSize(kw);
329 if (!sym)
330 return parser.emitError(parser.getCurrentLocation())
331 << "invalid block_size value: " << kw;
332 auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
333 outAttrs.push_back(NamedAttribute(name, attr));
334 return success();
335 }
336 }
337
338 // Default path: parse any normal attribute literal, including fully qualified
339 // enum keyword
340 Attribute attr;
341 return parser.parseAttribute(attr, name, outAttrs);
342}
343
344template <typename EnumType>
345ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
346 // parse operands
347 SmallVector<OpAsmParser::UnresolvedOperand> operands;
348 if (parser.parseCommaSeparatedList(
349 [&]() { return parser.parseOperand(operands.emplace_back()); }))
350 return failure();
351
352 // Parse { attr-dict } with special handling for enum bare token
353 NamedAttrList attrs;
354 if (succeeded(parser.parseOptionalLBrace()) &&
355 failed(parser.parseOptionalRBrace())) {
356 do {
357 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
358 return failure();
359 } while (succeeded(parser.parseOptionalComma()));
360 if (parser.parseRBrace())
361 return failure();
362 }
363
364 FunctionType fnTy;
365 if (parser.parseColonType(fnTy))
366 return failure();
367
368 // Resolve operands and types
369 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
370 parser.getCurrentLocation(),
371 result.operands)))
372 return failure();
373
374 result.addTypes(fnTy.getResults());
375 result.addAttributes(attrs);
376
377 return success();
378}
379
380void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
381 parser << namedAttr.getName().strref() << " = ";
382 auto attr = namedAttr.getValue();
383 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
384 parser << roundingModeAttr.getValue();
385 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
386 parser << resizeModeAttr.getValue();
387 } else if (auto nanPropagationModeAttr =
388 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
389 parser << nanPropagationModeAttr.getValue();
390 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
391 parser << blockSizeAttr.getValue();
392 } else {
393 parser.printAttribute(attr);
394 }
395}
396
397// print with special handling for default valued NanPropagationMode attribute
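// In other words, a nan_mode attribute holding the default value PROPAGATE is
// elided from the printed form; e.g. (illustration only) a tosa.maximum with
// nan_mode = PROPAGATE prints without a nan_mode entry, while nan_mode = IGNORE
// is printed explicitly.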
398void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
399 parser << " ";
400 parser.printOperands(op->getOperands());
401
402 NamedAttrList toPrint(op->getAttrs());
403 // remove default NanPropagate attribute
404 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
405 for (auto attr : op->getAttrs()) {
406 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
407 if (nanAttr.getValue() == kDefaultNanValue) {
408 // elide from toPrint
409 toPrint.erase(attr.getName());
410 break;
411 }
412 }
413 }
414
415 if (!toPrint.empty()) {
416 parser << " {";
417 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
418 printNamedAttr(parser, namedAttr);
419 });
420 parser << "}";
421 }
422
423 parser << " : ";
424 parser.printFunctionalType(op);
425}
426
427// print with special handling for enums: RoundingMode, ResizeMode
428void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
429 parser << " ";
430 parser.printOperands(op->getOperands());
431
432 if (!op->getAttrs().empty()) {
433 parser << " {";
434 llvm::interleaveComma(op->getAttrs(), parser,
435 [&](const NamedAttribute namedAttr) {
436 printNamedAttr(parser, namedAttr);
437 });
438 parser << "}";
439 }
440
441 parser << " : ";
442 parser.printFunctionalType(op);
443}
444
445} // namespace
446
447ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
448 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
449}
450
451void RescaleOp::print(OpAsmPrinter &parser) {
452 printWithEnumHandling(parser, *this);
453}
454
455ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
456 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
457}
458
459void ApplyScaleOp::print(OpAsmPrinter &parser) {
460 printWithEnumHandling(parser, *this);
461}
462
463ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
464 return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
465}
466
467void ResizeOp::print(OpAsmPrinter &parser) {
468 printWithEnumHandling(parser, *this);
469}
470
471ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
472 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
473}
474
475void ArgMaxOp::print(OpAsmPrinter &parser) {
476 printWithNanPropagationHandling(parser, *this);
477}
478
479ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
480 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
481}
482
483void MaxPool2dOp::print(OpAsmPrinter &parser) {
484 printWithNanPropagationHandling(parser, *this);
485}
486
487ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
488 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
489}
490
491void ClampOp::print(OpAsmPrinter &parser) {
492 printWithNanPropagationHandling(parser, *this);
493}
494
495ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
496 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
497}
498
499void MaximumOp::print(OpAsmPrinter &parser) {
500 printWithNanPropagationHandling(parser, *this);
501}
502
503ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
504 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
505}
506
507void MinimumOp::print(OpAsmPrinter &parser) {
508 printWithNanPropagationHandling(parser, *this);
509}
510
511ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
512 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
513}
514
515void ReduceMaxOp::print(OpAsmPrinter &parser) {
516 printWithNanPropagationHandling(parser, *this);
517}
518
519ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
520 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
521}
522
523void ReduceMinOp::print(OpAsmPrinter &parser) {
524 printWithNanPropagationHandling(parser, *this);
525}
526
527ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
528 OperationState &result) {
529 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
530}
531
532void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
533 printWithEnumHandling(parser, *this);
534}
535
536ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
537 OperationState &result) {
538 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
539}
540
541void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
542 printWithEnumHandling(parser, *this);
543}
544
545ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
546 OperationState &result) {
547 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
548}
549
550void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
551 printWithEnumHandling(parser, *this);
552}
553
554ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
555 OperationState &result) {
556 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
557}
558
559void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
560 printWithEnumHandling(parser, *this);
561}
562
563//===----------------------------------------------------------------------===//
564// Tosa utilities.
565//===----------------------------------------------------------------------===//
566
567static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
568 if (lhs % rhs != 0)
569 return std::nullopt;
570 return lhs / rhs;
571}
572
573Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
574 auto srcType = getElementTypeOrSelf(type);
575 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
576 srcType = getStorageElementTypeFromQuantized(quantType);
577 return srcType;
578}
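// E.g. (illustration): for tensor<4x!quant.uniform<i8:f32, 0.5>> the storage
// element type i8 is returned; for a plain tensor<4xf32>, f32 is returned
// unchanged.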
579
580Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
581 return getStorageElementTypeOrSelf(value.getType());
582}
583
584static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
585 Value valZp, StringRef name) {
586 Type eType = getStorageElementTypeOrSelf(val.getType());
587 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
588
589 bool bothInts =
590 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
591 bool sameBitWidth =
592 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
593
594 if (!bothInts || !sameBitWidth) {
595 return op->emitOpError()
596 << "expected " << name << " and " << name
597 << "_zp to both be integer of the same bitwidth, but got " << eType
598 << " vs. " << eZpType;
599 }
600 return success();
601}
602
603// Create a pad-const const tensor with value of `val` of required data-type
604Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
605 Value src, int32_t val) {
606 const auto srcType = getElementTypeOrSelf(src);
607 const auto srcElemType = getStorageElementTypeOrSelf(src);
608 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
609 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
610 const auto padConstAttr{
611 llvm::isa<FloatType>(srcElemType)
612 ? DenseElementsAttr::get(padConstEType,
613 builder.getFloatAttr(srcElemType, val))
614 : DenseElementsAttr::get(padConstEType,
615 builder.getIntegerAttr(srcElemType, val))};
616 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
617}
618
620 if (dyn_cast<tosa::mxint8Type>(type))
621 return 8;
622 return type.getIntOrFloatBitWidth();
623}
624
625// Update dim size if current dim is dynamic, otherwise raise an error if sizes
626// do not match
627LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
628 const int64_t newDim,
629 const StringRef operandName,
630 const StringRef dimName) {
631 if (ShapedType::isDynamic(currDim)) {
632 currDim = newDim;
633 return success();
634 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
635 return op->emitOpError("expected ")
636 << dimName << " of " << operandName << " to match size " << currDim
637 << ", got " << newDim;
638 }
639 return success();
640}
641
642static LogicalResult verifyConvOutputSize(
643 Operation *op, const int64_t inputSize, const int64_t kernelSize,
644 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
645 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
646 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
647 const llvm::StringRef padAfterName) {
648 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
649 return success();
650
651 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
652
653 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
654 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
655 stride);
656 if (!calculatedOutSizeMinusOne.has_value())
657 return op->emitOpError("expected input_")
658 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
659 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
660 << dimAxis << " to be wholly divisible by stride_" << dimAxis
661 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
662 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
663 << ") / " << stride;
664
665 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
666 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
667 return op->emitOpError("calculated output ")
668 << dimName << " did not match expected: "
669 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
670
671 return success();
672}
673
674//===----------------------------------------------------------------------===//
675// TOSA Operator Verifiers.
676//===----------------------------------------------------------------------===//
677
678template <typename T>
679static LogicalResult verifyConvOp(T op) {
680 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
681 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
682
683 auto inputEType = inputType.getElementType();
684 auto weightEType = weightType.getElementType();
685 auto biasEType =
686 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
687 auto resultEType =
688 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
689 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
690 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
691
692 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
693 inputEType = getStorageElementTypeFromQuantized(quantType);
694
695 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
696 weightEType = getStorageElementTypeFromQuantized(quantType);
697
698 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
699 biasEType = getStorageElementTypeFromQuantized(quantType);
700
701 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
702 resultEType = getStorageElementTypeFromQuantized(quantType);
703
704 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
705 // for now, only enforce bias element type == result element type for
706 // float types.
707 op.emitOpError(
708 "expect both bias and result to have same element type, got ")
709 << biasEType << " and " << resultEType;
710 return failure();
711 }
712
713 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
714 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
715 if (inputEType != weightEType) {
716 op.emitOpError(
717 "expect both input and weight to have same element type, got ")
718 << inputEType << " and " << weightEType;
719 return failure();
720 }
721 }
722
723 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
724 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
725
726 // Either both must be float or both non-float.
727 if (inputIsFloat != weightIsFloat) {
728 op.emitOpError(
729 "expect both input and weight to be float or not together, got ")
730 << inputEType << " and " << weightEType;
731 return failure();
732 }
733
734 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
735 if (inputEType != inputZpEType) {
736 return op.emitOpError("expect both input and its zero point are the same "
737 "element type, got ")
738 << inputEType << " and " << inputZpEType;
739 }
740
741 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
742 if (weightEType != weightZpEType) {
743 return op.emitOpError("expect both weight and its zero point are the same "
744 "element type, got ")
745 << weightEType << " and " << weightZpEType;
746 }
747
748 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
749 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
750 return failure();
751
752 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
753 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
754 return failure();
755
756 return success();
757}
758
759LogicalResult tosa::ConstOp::verify() {
760
761 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
762 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
763
764 if (!attrType || !outputType) {
765 emitOpError("expected tensors for attr/result type");
766 return failure();
767 }
768
769 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
770 outputType.getElementType())) {
771 if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
772 return success();
773 }
774
775 if (attrType.getElementType() != outputType.getElementType()) {
776 emitOpError("expected same attr/result element types");
777 return failure();
778 }
779
780 return success();
781}
782
783template <typename T>
784static LogicalResult verifyConvOpModes(T op) {
785 auto inputEType =
786 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
787
788 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
789 inputEType = getStorageElementTypeFromQuantized(quantType);
790
791 auto accType = op.getAccType();
792 if (inputEType.isInteger(8) && !accType.isInteger(32))
793 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
794 << accType;
795
796 if (inputEType.isInteger(16) && !accType.isInteger(48))
797 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
798 << accType;
799
800 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
801 !(accType.isF16() || accType.isF32()))
802 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
803 << accType;
804
805 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
806 return op.emitOpError(
807 "accumulator type for f16 tensor is not f16/f32, got ")
808 << accType;
809
810 if (inputEType.isBF16() && !accType.isF32())
811 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
812 << accType;
813
814 if (inputEType.isF32() && !accType.isF32())
815 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
816 << accType;
817
818 auto resultEType =
819 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
820
821 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
822 resultEType = getStorageElementTypeFromQuantized(quantType);
823
824 return success();
825}
826
827//===----------------------------------------------------------------------===//
828// ERROR_IF functions.
829// ERROR_IF is a predicate that must set an error if the condition holds.
830//===----------------------------------------------------------------------===//
831
832template <typename T>
833static LogicalResult verifyConvOpErrorIf(T op) {
834 llvm::ArrayRef<int64_t> padding = op.getPad();
835 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
836 return op.emitOpError("expect all padding values to be >= 0, got ")
837 << padding;
838
839 llvm::ArrayRef<int64_t> strides = op.getStride();
840 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
841 return op.emitOpError("expect all stride values to be >= 1, got ")
842 << strides;
843
844 llvm::ArrayRef<int64_t> dilations = op.getDilation();
845 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
846 return op.emitOpError("expect all dilation values to be >= 1, got ")
847 << dilations;
848
849 const RankedTensorType outputType =
850 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
851 if (!outputType)
852 // Skip following checks if output is not ranked
853 return success();
854
855 const RankedTensorType inputType =
856 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
857 const RankedTensorType weightType =
858 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
859
860 if (inputType && weightType) {
861 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
862 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
863 if (failed(verifyConvOutputSize(
864 op, inputType.getDimSize(1), weightType.getDimSize(1),
865 outputType.getDimSize(1), padding[0], padding[1], strides[0],
866 dilations[0], "height", "y", "top", "bottom")))
867 return failure();
868
869 if (failed(verifyConvOutputSize(
870 op, inputType.getDimSize(2), weightType.getDimSize(2),
871 outputType.getDimSize(2), padding[2], padding[3], strides[1],
872 dilations[1], "width", "x", "left", "right")))
873 return failure();
874 }
875
876 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
877 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
878 if (failed(verifyConvOutputSize(
879 op, inputType.getDimSize(1), weightType.getDimSize(0),
880 outputType.getDimSize(1), padding[0], padding[1], strides[0],
881 dilations[0], "height", "y", "top", "bottom")))
882 return failure();
883
884 if (failed(verifyConvOutputSize(
885 op, inputType.getDimSize(2), weightType.getDimSize(1),
886 outputType.getDimSize(2), padding[2], padding[3], strides[1],
887 dilations[1], "width", "x", "left", "right")))
888 return failure();
889 }
890
891 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
892 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
893 if (failed(verifyConvOutputSize(
894 op, inputType.getDimSize(1), weightType.getDimSize(1),
895 outputType.getDimSize(1), padding[0], padding[1], strides[0],
896 dilations[0], "depth", "d", "front", "back")))
897 return failure();
898
899 if (failed(verifyConvOutputSize(
900 op, inputType.getDimSize(2), weightType.getDimSize(2),
901 outputType.getDimSize(2), padding[2], padding[3], strides[1],
902 dilations[1], "height", "y", "top", "bottom")))
903 return failure();
904
905 if (failed(verifyConvOutputSize(
906 op, inputType.getDimSize(3), weightType.getDimSize(3),
907 outputType.getDimSize(3), padding[4], padding[5], strides[2],
908 dilations[2], "width", "x", "left", "right")))
909 return failure();
910 }
911 }
912
913 const RankedTensorType biasType =
914 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
915 if (!biasType)
916 // Skip following checks if bias is not ranked
917 return success();
918
919 const int64_t biasChannels = biasType.getDimSize(0);
920 const int64_t outputChannels =
921 outputType.getDimSize(outputType.getRank() - 1);
922 if (biasChannels == ShapedType::kDynamic ||
923 outputChannels == ShapedType::kDynamic)
924 // Skip following checks if biasChannels or outputChannels is dynamic dim
925 return success();
926
927 if (biasChannels != outputChannels && biasChannels != 1)
928 return op.emitOpError(
929 "bias channels expected to be equal to output channels (")
930 << outputChannels << ") or 1, got " << biasChannels;
931
932 return success();
933}
934
935// Verify that the two given types have the same element type and compatible shapes.
936static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
937 StringRef name1, Type type2,
938 StringRef name2) {
939 auto shapeType1 = dyn_cast<ShapedType>(type1);
940 auto shapeType2 = dyn_cast<ShapedType>(type2);
941 if (!shapeType1 || !shapeType2)
942 return failure();
943
944 auto elemType1 = shapeType1.getElementType();
945 auto elemType2 = shapeType2.getElementType();
946 if (elemType1 != elemType2)
947 return op->emitOpError()
948 << "require same element type for " << name1 << " (" << elemType1
949 << ") and " << name2 << " (" << elemType2 << ")";
950
951 if (failed(verifyCompatibleShape(type1, type2)))
952 return op->emitOpError()
953 << "require same shapes for " << name1 << " (" << type1 << ") and "
954 << name2 << " (" << type2 << ")";
955
956 return success();
957}
958
959// Verify that the two tensor lists have the same length and element-wise matching
// types and shapes.
960static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
961 StringRef name1,
962 ValueRange list2,
963 StringRef name2) {
964 if (list1.size() != list2.size())
965 return op->emitOpError()
966 << "require same number of values in " << name1 << " ("
967 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
968
969 for (auto [type1, type2] :
970 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
971 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
972 return failure();
973 }
974
975 return success();
976}
977
978static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
979 ShapeAdaptor shapeAdaptor(type);
980 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
981 return success();
982
983 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
984}
985
986template <typename T>
987static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
988 Operation *symTableOp =
989 op->template getParentWithTrait<OpTrait::SymbolTable>();
990 if (!symTableOp)
991 // If the operation is not within the scope of a symbol table, we cannot
992 // verify it against its declaration.
993 return success();
994
995 SymbolTable symTable(symTableOp);
996 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
997
998 // Verify prior declaration
999 if (!varOp)
1000 return op->emitOpError("'")
1001 << op.getName() << "' has not been declared by 'tosa.variable'";
1002
1003 // Verify type and shape
1004 auto variableType = getVariableType(varOp);
1005 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
1006 "the input tensor")
1007 .failed())
1008 return failure();
1009 return success();
1010}
1011
1012// Verify that aType and bType have the same element type.
1013template <typename T>
1014static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1015 StringRef aName = "input",
1016 StringRef bName = "output") {
1017 auto aTType = llvm::dyn_cast<TensorType>(aType);
1018 auto bTType = llvm::dyn_cast<TensorType>(bType);
1019 if (!aTType) {
1020 op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
1021 return failure();
1022 }
1023 if (!bTType) {
1024 op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
1025 return failure();
1026 }
1027 auto aElementType = aTType.getElementType();
1028 auto bElementType = bTType.getElementType();
1029 auto aQuantType =
1030 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1031 auto bQuantType =
1032 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1033 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1034 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1035 aElementType != bElementType) {
1036 // only check if both element types are int/index/float/UniformQuantized
1037 // eg, not sure how to check quant::QuantizedType
1038 // this happens in test_conv2d_q_grouped_convolution in
1039 // tfl-to-tosa-pipeline.mlir
1040 op.emitOpError("expect ")
1041 << aName << " and " << bName << " to have same element type, got "
1042 << aElementType << " and " << bElementType;
1043 return failure();
1044 }
1045 return success();
1046}
1047
1048LogicalResult tosa::ArgMaxOp::verify() {
1049 const ShapedType resultType = llvm::cast<ShapedType>(getType());
1050
1051 // Ensure the output element type is an integer type
1052 if (const auto resultETy = resultType.getElementType();
1053 !resultETy.isIntOrIndex())
1054 return emitOpError("result tensor is not of integer type");
1055
1056 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1057 if (!inputType.hasRank())
1058 return success();
1059
1060 // Ensure axis is within the tensor rank
1061 const int64_t axis = getAxisAttr().getInt();
1062 if (((axis < 0) || axis >= inputType.getRank()))
1063 return emitOpError("specified axis is outside the rank of the tensor");
1064
1065 if (!resultType.hasRank())
1066 return success();
1067
1068 const ArrayRef<int64_t> inputShape = inputType.getShape();
1069 const ArrayRef<int64_t> outputShape = resultType.getShape();
1070 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1071 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1072 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1073 return emitOpError("expected output shape '")
1074 << expectedOutputShape << "', got '" << outputShape << "'";
1075
1076 return success();
1077}
1078
1079template <typename T>
1080static LogicalResult verifyPoolingOp(T op) {
1081 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1082 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1083 return op.emitOpError("expect all kernel values to be >= 1, got ")
1084 << kernel;
1085
1086 const llvm::ArrayRef<int64_t> strides = op.getStride();
1087 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1088 return op.emitOpError("expect all stride values to be >= 1, got ")
1089 << strides;
1090
1091 const llvm::ArrayRef<int64_t> padding = op.getPad();
1092 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1093 return op.emitOpError("expect all padding values to be >= 0, got ")
1094 << padding;
1095
1096 // Padding must be less than kernel size to avoid a divide-by-zero
1097 const int64_t kernelX = kernel[1];
1098 const int64_t padLeft = padding[2];
1099 const int64_t padRight = padding[3];
1100 if (padRight >= kernelX || padLeft >= kernelX)
1101 return op.emitOpError("expected left/right padding to be less than the "
1102 "width of the kernel, got pad_left=")
1103 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1104
1105 const int64_t kernelY = kernel[0];
1106 const int64_t padTop = padding[0];
1107 const int64_t padBottom = padding[1];
1108 if (padTop >= kernelY || padBottom >= kernelY)
1109 return op.emitOpError("expected top/bottom padding to be less than the "
1110 "height of the kernel, got pad_top=")
1111 << padTop << ", pad_bottom=" << padBottom
1112 << ", kernel_y=" << kernelY;
1113
1114 const auto inputType =
1115 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1116 const auto outputType =
1117 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1118 if (!inputType || !outputType)
1119 return success();
1120
1121 const auto verifyOutputSize =
1122 [&op](const int64_t inputSize, const int64_t outputSize,
1123 const int64_t kernelSize, const int64_t strideSize,
1124 const int64_t padBefore, const int64_t padAfter,
1125 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1126 const llvm::StringRef padBeforeName,
1127 const llvm::StringRef padAfterName) -> LogicalResult {
1128 if (ShapedType::isDynamic(inputSize))
1129 return success();
1130
1131 const std::optional<int64_t> calculatedOutSizeMinusOne =
1132 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1133 if (!calculatedOutSizeMinusOne.has_value())
1134 return op.emitOpError("expected input_")
1135 << dimName << " + pad_" << padBeforeName << " + pad_"
1136 << padAfterName << " - kernel_" << dimAxis
1137 << " to be wholly divisible by stride_" << dimAxis << ", got ("
1138 << inputSize << " + " << padBefore << " + " << padAfter << " - "
1139 << kernelSize << ") / " << strideSize;
1140
1141 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1142 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1143 return op.emitOpError("calculated output ")
1144 << dimName << " did not match expected: " << "calculated="
1145 << calculatedOutSize << ", expected=" << outputSize;
1146
1147 return success();
1148 };
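  // Worked example (illustration only): input_height = 112, kernel_y = 3,
  // stride_y = 2, pad_top = 0, pad_bottom = 1 gives
  // (112 + 0 + 1 - 3) / 2 + 1 = 56 as the expected output height.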
1149
1150 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1151 kernel[0], strides[0], padding[0], padding[1],
1152 "height", "y", "top", "bottom")))
1153 return failure();
1154
1155 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1156 kernel[1], strides[1], padding[2], padding[3],
1157 "width", "x", "left", "right")))
1158 return failure();
1159
1160 return success();
1161}
1162
1163LogicalResult tosa::AvgPool2dOp::verify() {
1164 if (failed(verifyPoolingOp(*this)))
1165 return failure();
1166
1167 const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1168 const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1169 const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1170 const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1171
1172 auto accType = getAccType();
1173 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1174 return emitOpError("accumulator type for integer tensor is not i32");
1175
1176 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1177 return emitOpError("accumulator type for f16 tensor is not f16/f32");
1178
1179 if (inputETy.isBF16() && !accType.isF32())
1180 return emitOpError("accumulator type for bf16 tensor is not f32");
1181
1182 if (inputETy.isF32() && !accType.isF32())
1183 return emitOpError("accumulator type for f32 tensor is not f32");
1184
1185 if (inputETy != inputZpETy)
1186 return emitOpError("expect both input and its zero point are the same "
1187 "element type, got ")
1188 << inputETy << " and " << inputZpETy;
1189
1190 if (resultETy != outputZpETy)
1191 return emitOpError("expect both output and its zero point are the same "
1192 "element type, got ")
1193 << resultETy << " and " << outputZpETy;
1194
1195 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1196 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1197 return failure();
1198
1199 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1200 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1201 return failure();
1202
1203 return success();
1204}
1205
1206LogicalResult tosa::ClampOp::verify() {
1207 mlir::Type inputETy =
1208 llvm::cast<ShapedType>(getInput().getType()).getElementType();
1209 if (auto quantType =
1210 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1211 inputETy = getStorageElementTypeFromQuantized(quantType);
1212 }
1213 mlir::Type outputETy =
1214 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1215 if (auto quantType =
1216 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1217 outputETy = getStorageElementTypeFromQuantized(quantType);
1218 }
1219 if (inputETy != outputETy)
1220 return emitOpError("input/output element types are incompatible.");
1221
1222 auto maxValAttr = getMaxValAttr();
1223 auto minValAttr = getMinValAttr();
1224
1225 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1226
1227 if (inputETy.isInteger(dataTypeBitWidth)) {
1228 // if input datatype is integer, check that the min_val/max_val attributes
1229 // are integer attributes, and that their type is the same as the input's
1230 // datatype
1231 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1232 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1233 if (!intMaxValAttr || !intMinValAttr ||
1234 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1235 (intMaxValAttr.getType() != inputETy))
1236 return emitOpError("min/max attributes types are incompatible with "
1237 "input/output element types.");
1238
1239 const bool isUnsigned = inputETy.isUnsignedInteger();
1240 const bool isBoolean = inputETy.isInteger(1);
1241 const APInt minVal = intMinValAttr.getValue();
1242 const APInt maxVal = intMaxValAttr.getValue();
1243 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1244 return emitOpError("expected min_val <= max_val, got min_val=")
1245 << minValAttr << ", max_val=" << maxValAttr;
1246 } else {
1247 // otherwise, input datatype is float, check that the min_val/max_val
1248 // attributes share the same type and that their type is the same as the
1249 // input's datatype
1250 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1251 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1252 if (!floatMaxValAttr || !floatMinValAttr ||
1253 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1254 (floatMaxValAttr.getType() != inputETy))
1255 return emitOpError("min/max attributes types are incompatible with "
1256 "input/output element types.");
1257
1258 const APFloat minVal = floatMinValAttr.getValue();
1259 const APFloat maxVal = floatMaxValAttr.getValue();
1260 if (minVal.isNaN() || maxVal.isNaN())
1261 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1262 << minValAttr << ", max_val=" << maxValAttr;
1263
1264 if (maxVal < minVal)
1265 return emitOpError("expected min_val <= max_val, got min_val=")
1266 << minValAttr << ", max_val=" << maxValAttr;
1267 }
1268
1269 return success();
1270}
1271
1272//===----------------------------------------------------------------------===//
1273// TOSA Operator Quantization Builders.
1274//===----------------------------------------------------------------------===//
1275
1276/// This builder is called on all convolution operators except TransposeConv,
1277/// which has specialized output shape semantics. The builder also defines the
1278/// bitwidth of the output given the bit width of the input & weight content.
1279static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1280 Type outputType, Value input, Value weight,
1281 Value bias, DenseI64ArrayAttr pad,
1282 DenseI64ArrayAttr stride,
1283 DenseI64ArrayAttr dilation,
1284 TypeAttr accType) {
1285 auto zps = createZPsAsConst(builder, input, weight);
1286 result.addOperands({input, weight, bias, zps.first, zps.second});
1287 result.addAttribute("pad", pad);
1288 result.addAttribute("stride", stride);
1289 result.addAttribute("dilation", dilation);
1290 result.addAttribute("acc_type", accType);
1291 Type finalOutputType = outputType;
1292 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1293 if (quantAttr) {
1294 finalOutputType =
1295 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1296 }
1297 result.addTypes(finalOutputType);
1298}
1299
1300/// Handles tosa.transpose_conv2d which has outpad and output shape
1301/// attributes.
1302static void
1303buildTransposeConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1304 Type outputType, Value input, Value weight,
1305 Value bias, DenseI64ArrayAttr outpad,
1306 DenseI64ArrayAttr stride, TypeAttr accType) {
1307 auto zps = createZPsAsConst(builder, input, weight);
1308 result.addOperands({input, weight, bias, zps.first, zps.second});
1309 result.addAttribute("out_pad", outpad);
1310 result.addAttribute("stride", stride);
1311 result.addAttribute("acc_type", accType);
1312 Type finalOutputType = outputType;
1313 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1314 if (quantAttr) {
1315 finalOutputType =
1316 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1317 }
1318 result.addTypes(finalOutputType);
1319}
1320
1321/// The tosa.matmul op is also intended to be generated where a fully_connected
1322/// op would be constructed but the weight is not a constant. In this case,
1323/// the fully_connected op must be expressed using matmul.
1324/// TODO: Add link to the legalization document explaining this.
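/// When quantization info is derived for the operands, the result element type
/// below is widened to the accumulator type: i48 for 16-bit integer inputs and
/// i32 otherwise.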
1325static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1326 OperationState &result, Type outputType,
1327 Value a, Value b) {
1328 auto zps = createZPsAsConst(builder, a, b);
1329 result.addOperands({a, b, zps.first, zps.second});
1330
1331 Type finalOutputType{outputType};
1332 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1333 auto eType = getStorageElementTypeOrSelf(a.getType());
1334 auto inputBits = eType.getIntOrFloatBitWidth();
1335
1336 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1337 assert(outputShapedType && "Output must be a shaped type");
1338
1339 IntegerType accElementType;
1340 if (inputBits == 16)
1341 accElementType = builder.getIntegerType(48);
1342 else
1343 accElementType = builder.getI32Type();
1344
1345 finalOutputType = outputShapedType.clone(accElementType);
1346 }
1347 result.addTypes(finalOutputType);
1348}
1349
1350/// Both the tosa.avg_pool2d and unary ops use the same
1351/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
1352/// it has additional parameters not part of the unary ops.
1353static void
1354buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1355 Type outputType, Value input,
1356 DenseArrayAttr kernel, DenseArrayAttr stride,
1357 DenseArrayAttr pad, TypeAttr accType) {
1358 const Location loc{result.location};
1359 int64_t inputZp{0};
1360 int64_t outputZp{0};
1361
1362 if (auto quantAttr =
1363 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1364 inputZp = quantAttr.getInputZp();
1365 outputZp = quantAttr.getOutputZp();
1366 }
1367 const std::optional<Value> inputZpOp =
1368 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1369 if (!inputZpOp) {
1370 (void)emitError(
1371 loc,
1372 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1373 }
1374 const std::optional<Value> outputZpOp =
1375 createZeroPointTensor(builder, loc, outputType, outputZp);
1376 if (!outputZpOp) {
1377 (void)emitError(loc, "Failed to create output zero point tensor for "
1378 "quantized AVG_POOL2D op");
1379 }
1380
1381 if (inputZpOp && outputZpOp) {
1382 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1383 } else {
1384 // failed to create one or more zero points above: just add input as
1385 // operands this will trigger error in building the op because of missing
1386 // zero points
1387 result.addOperands({input});
1388 }
1389 result.addAttribute("kernel", kernel);
1390 result.addAttribute("stride", stride);
1391 result.addAttribute("pad", pad);
1392 result.addAttribute("acc_type", accType);
1393 result.types.push_back(outputType);
1394}
1395
1396/// This builder is called on the single-input negate operator
1397/// to construct its input and output zero points based on their
1398/// types.
1399static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1400 OperationState &result, Type outputType,
1401 Value input) {
1402 const Location loc{result.location};
1403 int64_t input1Zp{0};
1404 int64_t outputZp{0};
1405 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1406 if (quantAttr) {
1407 input1Zp = quantAttr.getInputZp();
1408 outputZp = quantAttr.getOutputZp();
1409 }
1410 const std::optional<Value> input1ZpOp =
1411 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1412 if (!input1ZpOp) {
1413 (void)emitError(
1414 loc, "Failed to create input1 zero point for quantized NEGATE op");
1415 }
1416
1417 const std::optional<Value> outputZpOp =
1418 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1419 if (!outputZpOp) {
1420 (void)emitError(
1421 loc, "Failed to create output zero point for quantized NEGATE op");
1422 }
1423
1424 if (input1ZpOp && outputZpOp) {
1425 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1426 } else {
1427 // failed to create one or more zero points above: just add input as
1428 // operands. This will trigger error in building the op because of
1429 // missing zero points
1430 result.addOperands({input});
1431 }
1432
1433 result.types.push_back(outputType);
1434}
1435
1436/// This builder is called on the TOSA pad operator that needs to create its own
1437/// OptionalAttr quantization_attr parameter to scale the padding values
1438/// correctly. When no pad_const is supplied, zero-padding is assumed.
1439static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1440 Type outputType, Value input,
1441 Value paddings) {
1442 const Location loc{result.location};
1443 int32_t zp{0};
1444 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1445 if (quantAttr) {
1446 zp = static_cast<int32_t>(quantAttr.getInputZp());
1447 }
1448 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1449 result.addOperands({input, paddings, padConstOp});
1450 result.types.push_back(outputType);
1451}
1452
1453static void buildVariableOp(OpBuilder &builder, OperationState &result,
1454 StringRef name, Type variableType,
1455 Attribute initialValue) {
1456 const Location loc{result.location};
1457 auto nameAttr = builder.getStringAttr(name);
1458
1459 auto shapedType = dyn_cast<ShapedType>(variableType);
1460 if (!shapedType) {
1461 (void)emitError(loc, "variable type must be a shaped type");
1462 return;
1463 }
1464 if (!shapedType.hasRank()) {
1465 (void)emitError(loc, "variable type must be a ranked type");
1466 return;
1467 }
1468
1469 auto elementType = shapedType.getElementType();
1470 auto elementTypeAttr = TypeAttr::get(elementType);
1471 ArrayRef<int64_t> shape = shapedType.getShape();
1472 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1473
1474 result.addAttribute("sym_name", nameAttr);
1475 result.addAttribute("var_shape", varShapeAttr);
1476 result.addAttribute("type", elementTypeAttr);
1477 result.addAttribute("initial_value", initialValue);
1478}
1479
1480//===----------------------------------------------------------------------===//
1481// TOSA Operator Return Type Inference.
1482//===----------------------------------------------------------------------===//
1483
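// Broadcast example (illustration only): operand shapes tensor<1x3xf32> and
// tensor<2x1x1xf32> resolve below to the output shape [2, 1, 3]; mismatched
// non-unit dimensions cause resolveBroadcastShape to fail.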
1484static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1485 SmallVector<int64_t> &outShape) {
1486 int64_t outRank = 0;
1487 for (int i = 0, e = operands.size(); i != e; ++i) {
1488 auto shape = operands.getShape(i);
1489 if (!shape.hasRank()) {
1490 // TODO(jennik): Update function to have better case handling for
1491 // invalid operands and for ranked tensors.
1492 return failure();
1493 }
1494 outRank = std::max<int64_t>(outRank, shape.getRank());
1495 }
1496
1497 outShape.resize(outRank, 1);
1498
1499 for (int i = 0, e = operands.size(); i != e; ++i) {
1500 auto shape = operands.getShape(i);
1501 auto rankDiff = outShape.size() - shape.getRank();
1502
1503 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1504 auto dim1 = outShape[i + rankDiff];
1505 auto dim2 = shape.getDimSize(i);
1506 auto resolvedDim = dim1;
1507
1508 if (dim1 == 1) {
1509 resolvedDim = dim2;
1510 } else if (dim2 == 1) {
1511 resolvedDim = dim1;
1512 } else if (dim1 != dim2) {
1513 return failure();
1514 }
1515 outShape[i + rankDiff] = resolvedDim;
1516 }
1517 }
1518
1519 return success();
1520}
1521
1522LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1523 MLIRContext *context, ::std::optional<Location> location,
1524 ArgMaxOp::Adaptor adaptor,
1525 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1526 ShapeAdaptor inputShape(adaptor.getInput().getType());
1527 IntegerAttr axis = adaptor.getProperties().axis;
1528 int32_t axisVal = axis.getValue().getSExtValue();
1529
1530 if (!inputShape.hasRank()) {
1531 inferredReturnShapes.push_back(ShapedTypeComponents());
1532 return success();
1533 }
1534
1535 SmallVector<int64_t> outShape;
1536 outShape.reserve(inputShape.getRank() - 1);
1537 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1538 if (i == axisVal)
1539 continue;
1540 outShape.push_back(inputShape.getDimSize(i));
1541 }
1542
1543 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1544 return success();
1545}
1546
1547LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1548 MLIRContext *context, ::std::optional<Location> location,
1549 RFFT2dOp::Adaptor adaptor,
1550 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1551 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1552
1553 if (!inputShape.hasRank())
1554 return failure();
1555
1556 llvm::SmallVector<int64_t> outputShape;
1557 outputShape.resize(3, ShapedType::kDynamic);
1558 outputShape[0] = inputShape.getDimSize(0);
1559 outputShape[1] = inputShape.getDimSize(1);
1560 int64_t inWidth = inputShape.getDimSize(2);
1561
1562 // Note that we can support this calculation symbolically
1563 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1564 if (inWidth != ShapedType::kDynamic)
1565 outputShape[2] = inWidth / 2 + 1;
1566
1567 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1568 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1569
1570 return success();
1571}
1572
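// Checks that a static dimension is a positive power of two; for example 1, 2
// and 8 are accepted, while 0 and 6 are rejected.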
1573static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1574 const llvm::StringRef dimName) {
1575 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1576 if (!isPowerOfTwo)
1577 return op->emitOpError("expected ")
1578 << dimName << " to be a power of two, got " << dimSize;
1579
1580 return success();
1581}
1582
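// For example, a static input_real of shape [2, 8, 16] requires both outputs
// to have shape [2, 8, 9], since output_width = input_width / 2 + 1.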
1583LogicalResult tosa::RFFT2dOp::verify() {
1584 const auto outputTypes = getResultTypes();
1585 if (failed(verifyCompatibleShapes(outputTypes)))
1586 return emitOpError("expected output shapes to match, got ") << outputTypes;
1587
1588 const auto inputType =
1589 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1590 if (!inputType)
1591 return success();
1592
1593 const int64_t height = inputType.getDimSize(1);
1594 if (ShapedType::isStatic(height) &&
1595 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1596 return failure();
1597
1598 const int64_t width = inputType.getDimSize(2);
1599 if (ShapedType::isStatic(width) &&
1600 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1601 return failure();
1602
1603 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1604 if (!outputType)
1605 return success();
1606
1607 // Batch and height input/output dimensions should match
1608 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1609 outputType.getShape().drop_back())))
1610 return emitOpError("expected batch and height dimensions of input/output "
1611 "to match, got input=")
1612 << inputType << " output=" << outputType;
1613
1614 // Output width dimension expected to be input_width / 2 + 1
1615 const int64_t outputWidth = outputType.getDimSize(2);
1616 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1617 (outputWidth != (width / 2) + 1))
1618 return emitOpError(
1619 "expected output width to be equal to input_width / 2 + 1, got ")
1620 << outputWidth;
1621
1622 return success();
1623}
1624
1625LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1626 MLIRContext *context, ::std::optional<Location> location,
1627 FFT2dOp::Adaptor adaptor,
1628 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1629 inferredReturnShapes.push_back(
1630 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1631 inferredReturnShapes.push_back(
1632 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1633 return success();
1634}
1635
1636LogicalResult tosa::FFT2dOp::verify() {
1637 const auto inputRealType =
1638 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1639 const auto inputImagType =
1640 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1641 if (!inputRealType || !inputImagType)
1642 return success();
1643
1644 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1645 return ShapedType::isDynamic(a) ? b : a;
1646 };
1647
1648 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1649 inputImagType.getDimSize(1));
1650 if (ShapedType::isStatic(height) &&
1651 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1652 return failure();
1653
1654 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1655 inputImagType.getDimSize(2));
1656 if (ShapedType::isStatic(width) &&
1657 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1658 return failure();
1659
1660 return success();
1661}
1662
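// For example, concatenating shapes [2, 3] and [2, 5] along axis 1 infers an
// output shape of [2, 8]; a dynamic axis dimension on any input makes the
// output axis dimension dynamic.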
1663LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1664 MLIRContext *context, ::std::optional<Location> location,
1665 ConcatOp::Adaptor adaptor,
1666 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1667 // Infer all dimension sizes by reducing based on inputs.
1668 const Properties &prop = adaptor.getProperties();
1669 int32_t axis = prop.axis.getValue().getSExtValue();
1670 llvm::SmallVector<int64_t> outputShape;
1671 bool hasRankedInput = false;
1672 for (auto operand : adaptor.getOperands()) {
1673 ShapeAdaptor operandShape(operand.getType());
1674 if (!operandShape.hasRank())
1675 continue;
1676
1677 // Copy the Operand's rank.
1678 if (!hasRankedInput)
1679 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1680
1681 // Copy shapes until the dim is non-dynamic.
1682 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1683 if (i == axis || operandShape.isDynamicDim(i))
1684 continue;
1685 if (outputShape[i] == ShapedType::kDynamic)
1686 outputShape[i] = operandShape.getDimSize(i);
1687 if (outputShape[i] != operandShape.getDimSize(i))
1688 return emitOptionalError(location,
1689 "Cannot concat tensors with different sizes"
1690 " on the non-axis dimension ",
1691 i);
1692 }
1693
1694 hasRankedInput = true;
1695 }
1696
1697 if (adaptor.getInput1().empty())
1698 return failure();
1699
1700 Type inputType =
1701 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1702 if (!hasRankedInput) {
1703 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1704 return success();
1705 }
1706
1707 // Determine the dimension size along the concatenation axis.
1708 int64_t concatDimSize = 0;
1709 for (auto operand : adaptor.getOperands()) {
1710 ShapeAdaptor operandShape(operand.getType());
1711
1712 // We need to know the length of the concatenation axis of all inputs to
1713 // determine the dimension size of the output shape.
1714 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1715 concatDimSize = ShapedType::kDynamic;
1716 break;
1717 }
1718
1719 concatDimSize += operandShape.getDimSize(axis);
1720 }
1721
1722 outputShape[axis] = concatDimSize;
1723
1724 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1725 return success();
1726}
1727
1728LogicalResult tosa::ConcatOp::verify() {
1729 // check that each input has same element type as output
1730 auto outType = getOutput().getType();
1731 const Operation::operand_range inputList = getInput1();
1732
1733 // Check there is at least one input
1734 if (inputList.empty())
1735 return emitOpError("expect at least one input");
1736
1737 if (!llvm::all_of(inputList, [&](auto input) {
1738 return succeeded(verifySameElementTypes(
1739 *this, /* inType = */ input.getType(), outType));
1740 })) {
1741 return failure();
1742 }
1743
1744 const int32_t axis = getAxis();
1745 ShapeAdaptor firstRankedInputShape = nullptr;
1746 for (const auto &input : inputList) {
1747 const Type inputType = input.getType();
1748 ShapeAdaptor currShape(inputType);
1749 if (currShape.hasRank()) {
1750 firstRankedInputShape = currShape;
1751 // Check axis is in expected range
1752 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1753 return emitOpError("expect axis to be within range 0 <= axis < "
1754 "rank(input1[firstRankedTensorIdx]), got ")
1755 << axis;
1756 break;
1757 }
1758 }
1759
1760 const auto allOperandsHasRank = [](const Value input) {
1761 return ShapeAdaptor(input.getType()).hasRank();
1762 };
1763 if (llvm::all_of(inputList, allOperandsHasRank)) {
1764 const int64_t firstInputRank = firstRankedInputShape.getRank();
1765
1766 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1767 const ShapeAdaptor inputShape(input.getType());
1768 const int64_t inputRank = inputShape.getRank();
1769 const size_t operandNum = index + 1;
1770
1771 // Check that each operand has the same rank
1772 if (inputRank != firstInputRank)
1773 return emitOpError(
1774 "expect all operands to have the same rank, but got ")
1775 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1776 << operandNum;
1777
1778 // Check non-axis dims match
1779 for (int i = 0; i < inputRank; i++) {
1780 const int64_t inputDim = inputShape.getDimSize(i);
1781 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1782 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1783 inputShape.isDynamicDim(i))
1784 continue;
1785 if (inputDim != firstInputDim)
1786 return emitOpError("expect all operand shapes to have the same sizes "
1787 "on non-axis dimensions, but got ")
1788 << inputDim << " vs " << firstInputDim << " at index " << i
1789 << " on operands 0 and " << operandNum;
1790 }
1791 }
1792
1793 const ShapeAdaptor outputShape(outType);
1794 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1795 return emitOpError("expect output rank to match inputs rank, got ")
1796 << outputShape.getRank() << " vs " << firstInputRank;
1797
1798 // ERROR_IF(axis_sum != shape[axis]);
1799 int64_t axisSum = 0;
1800 for (const auto &input : inputList) {
1801 const ShapeAdaptor inputShape(input.getType());
1802 if (inputShape.isDynamicDim(axis)) {
1803 // make axisSum negative to indicate invalid value
1804 axisSum = -1;
1805 break;
1806 }
1807 axisSum += inputShape.getDimSize(axis);
1808 }
1809
1810 if (axisSum >= 0 && outputShape.hasRank() &&
1811 !outputShape.isDynamicDim(axis) &&
1812 axisSum != outputShape.getDimSize(axis))
1813 return emitOpError("requires sum of axis dimensions of input1 "
1814 "equal to output axis dimension, got ")
1815 << axisSum << " and " << outputShape.getDimSize(axis);
1816 }
1817
1818 return success();
1819}
1820
1821LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1822 MLIRContext *context, ::std::optional<Location> location,
1823 ValueShapeRange operands, DictionaryAttr attributes,
1824 OpaqueProperties properties, RegionRange regions,
1825 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1826 auto elementType = IntegerType::get(context, /*width=*/1);
1827
1828 llvm::SmallVector<int64_t> outShape;
1829 if (resolveBroadcastShape(operands, outShape).failed()) {
1830 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1831 return success();
1832 }
1833
1834 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1835 return success();
1836}
1837
1838bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1839 if (l.size() != r.size() || l.size() != 1)
1840 return false;
1841 return succeeded(verifyCompatibleShape(l[0], r[0]));
1842}
1843
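// Shapes follow the [N, H, C] x [N, C, W] -> [N, H, W] convention; for example
// inputs of shape [1, 4, 8] and [1, 8, 16] infer an output of [1, 4, 16].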
1844LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1845 MLIRContext *context, ::std::optional<Location> location,
1846 MatMulOp::Adaptor adaptor,
1847 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1848 ShapeAdaptor lhsShape(adaptor.getA().getType());
1849 ShapeAdaptor rhsShape(adaptor.getB().getType());
1850
1851 // All shapes are dynamic.
1852 SmallVector<int64_t> outShape;
1853 outShape.resize(3, ShapedType::kDynamic);
1854
1855 if (lhsShape.hasRank()) {
1856 outShape[0] = lhsShape.getDimSize(0);
1857 outShape[1] = lhsShape.getDimSize(1);
1858 }
1859
1860 if (rhsShape.hasRank()) {
1861 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1862 : outShape[0];
1863 outShape[2] = rhsShape.getDimSize(2);
1864 }
1865
1866 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1867 return success();
1868}
1869
1870LogicalResult MatMulOp::verify() {
1871 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1872 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1873
1874 // Must be shaped tensor types
1875 if (!aType)
1876 return emitOpError("expect a shaped tensor for input a, got ")
1877 << getA().getType();
1878
1879 if (!bType)
1880 return emitOpError("expect a shaped tensor for input b, got ")
1881 << getB().getType();
1882
1883 auto aElementType = aType.getElementType();
1884 auto bElementType = bType.getElementType();
1885
1886 auto aQuantizedEType =
1887 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1888 auto bQuantizedEType =
1889 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1890
1891 if (aQuantizedEType || bQuantizedEType) {
1892 if (!aQuantizedEType || !bQuantizedEType) {
1893 return emitOpError("expect operands to be both quantized or both not "
1894 "quantized, got ")
1895 << aElementType << " and " << bElementType;
1896 }
1897 // both a and b have quantized element types
1898 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1899 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1900 if (aQuantWidth != bQuantWidth) {
1901 return emitOpError("expect quantized operands to have same widths, got ")
1902 << aQuantWidth << " and " << bQuantWidth;
1903 }
1904 }
1905
1906 // check a_zp and b_zp
1907 auto aEType = getStorageElementTypeOrSelf(aType);
1908 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1909 if (aEType != aZpEType) {
1910 return emitOpError("expect input a and a_zp have the same "
1911 "element type, got ")
1912 << aEType << " and " << aZpEType;
1913 }
1914
1915 auto bEType = getStorageElementTypeOrSelf(bType);
1916 auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1917 if (bEType != bZpEType) {
1918 return emitOpError("expect input b and b_zp have the same "
1919 "element type, got ")
1920 << bEType << " and " << bZpEType;
1921 }
1922
1923 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1924 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1925 return failure();
1926
1927 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1928 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1929 return failure();
1930
1931 return success();
1932}
1933
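// The output shape is [N, H, W]: N and H are taken from A_data/A_scale and W
// from B_data/B_scale; a B batch size of 1 is broadcast across A's batch size.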
1934LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1935 MLIRContext *context, ::std::optional<Location> location,
1936 MatmulTBlockScaledOp::Adaptor adaptor,
1937 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1938 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1939
1940 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1941 if (aDataShape.hasRank()) {
1942 outShape[0] = aDataShape.getDimSize(0);
1943 outShape[1] = aDataShape.getDimSize(1);
1944 }
1945
1946 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1947 if (aScaleShape.hasRank()) {
1948 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1949 : outShape[0];
1950 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1951 : outShape[1];
1952 }
1953
1954 // If B batch size is 1, it is broadcast across A's batch size
1955 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1956 if (bDataShape.hasRank()) {
1957 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1958 if (bDataBatchSize != 1)
1959 outShape[0] =
1960 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1961 outShape[2] = bDataShape.getDimSize(1);
1962 }
1963
1964 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1965 if (bScaleShape.hasRank()) {
1966 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1967 if (bScaleBatchSize != 1)
1968 outShape[0] =
1969 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1970 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1971 : outShape[2];
1972 }
1973
1974 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1975 return success();
1976}
1977
1978LogicalResult MatmulTBlockScaledOp::verify() {
1979 // Verify same input data types
1980 const Type aDataType = getAData().getType();
1981 const Type bDataType = getBData().getType();
1982 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
1983 "B_data")))
1984 return failure();
1985
1986 // Verify input shape compatibility
1987 int64_t N = ShapedType::kDynamic;
1988 int64_t D = ShapedType::kDynamic;
1989 int64_t H = ShapedType::kDynamic;
1990 int64_t W = ShapedType::kDynamic;
1991 int64_t C = ShapedType::kDynamic;
1992 int64_t multiplesOfC = ShapedType::kDynamic;
1993
1994 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
1995 if (aDataShape.hasRank()) {
1996 N = aDataShape.getDimSize(0);
1997 H = aDataShape.getDimSize(1);
1998 C = aDataShape.getDimSize(2);
1999 }
2000
2001 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
2002 if (aScaleShape.hasRank()) {
2003 if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
2004 "a_scale", "batch")) ||
2005 failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
2006 "a_scale", "height")))
2007 return failure();
2008 multiplesOfC = aScaleShape.getDimSize(2);
2009 }
2010
2011 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
2012 if (bDataShape.hasRank()) {
2013 if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
2014 "b_data", "batch")) ||
2015 failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
2016 "b_data", "channels")))
2017 return failure();
2018 W = bDataShape.getDimSize(1);
2019 }
2020
2021 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2022 if (bScaleShape.hasRank()) {
2023 if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
2024 "b_scale", "batch")) ||
2025 failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
2026 "b_scale", "width")) ||
2027 failed(tryUpdateDimOrFailure(*this, multiplesOfC,
2028 bScaleShape.getDimSize(2), "b_scale",
2029 "C/block_size")))
2030 return failure();
2031 }
2032
2033 // Verify batch size is broadcast compatible
2034 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2035 return emitOpError("expect B matrix batch size to be broadcast compatible "
2036 "with A, got D=")
2037 << D << " vs N=" << N;
2038
2039 // Verify C is a multiple of block size
2040 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2041 if (ShapedType::isStatic(C) && C % blockSize != 0)
2042 return emitOpError("expect C to be a multiple of block size, got C=")
2043 << C << ", block_size=" << blockSize;
2044
2045 // Verify multiplesOfC is C / block size
2046 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2047 multiplesOfC != C / blockSize)
2048 return emitOpError(
2049 "expect scale operands dimension 2 to equal C/block_size (")
2050 << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
2051
2052 // Verify output shape
2053 N = ShapedType::isDynamic(N) ? D : N;
2054 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2055 const auto outputType = cast<ShapedType>(getResult().getType());
2056 if (outputType.hasRank() &&
2057 failed(
2058 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2059 InFlightDiagnostic opError = emitOpError("expected output shape ");
2060 auto stringifyDim = [&](int64_t d) {
2061 if (ShapedType::isDynamic(d))
2062 opError << "?";
2063 else
2064 opError << d;
2065 };
2066 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2067 opError << " to be compatible with expected output shape ";
2068 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2069 return opError;
2070 }
2071
2072 return success();
2073}
2074
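// For example, an input of shape [2, 3] with constant padding [1, 1, 2, 2]
// infers an output of [4, 7]; a negative (-1) padding entry makes the
// corresponding output dimension dynamic.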
2075LogicalResult tosa::PadOp::inferReturnTypeComponents(
2076 MLIRContext *context, ::std::optional<Location> location,
2077 PadOp::Adaptor adaptor,
2078 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2079 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2080 auto paddingRank =
2081 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2082 SmallVector<int64_t> outputShape;
2083
2084 // If the input rank is unknown, we can infer the output rank using the
2085 // padding shape's rank divided by 2.
2086 if (!inputShape.hasRank()) {
2087 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2088 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2089 return success();
2090 }
2091
2092 SmallVector<int64_t> paddingValues;
2093 // If the paddings value is not a constant, all dimensions must be dynamic.
2094 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2095 paddingValues)) {
2096 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2097 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2098 return success();
2099 }
2100
2101 outputShape.reserve(inputShape.getRank());
2102 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2103 if (inputShape.isDynamicDim(i)) {
2104 outputShape.push_back(ShapedType::kDynamic);
2105 continue;
2106 }
2107 auto padFront = paddingValues[i * 2];
2108 auto padBack = paddingValues[i * 2 + 1];
2109 if (padFront < 0 || padBack < 0) {
2110 // if either padding for dim i is -1, output dim is unknown
2111 outputShape.push_back(ShapedType::kDynamic);
2112 continue;
2113 }
2114
2115 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2116 }
2117
2118 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2119 return success();
2120}
2121
2122LogicalResult tosa::PadOp::verify() {
2123 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2124 /* outType = */ getOutput().getType())
2125 .failed()) {
2126 return failure();
2127 }
2128
2129 if (auto padConst = getPadConst()) {
2130 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2131 /* outType = */ getOutput().getType())
2132 .failed()) {
2133 return failure();
2134 }
2135 }
2136
2137 RankedTensorType inputType =
2138 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2139 RankedTensorType outputType =
2140 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2141 if (!inputType || !outputType)
2142 return success();
2143
2144 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2145 "output")))
2146 return failure();
2147
2148 auto inputRank = inputType.getRank();
2149 DenseIntElementsAttr paddingAttr;
2150 if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
2151 return failure();
2152 }
2153
2154 auto paddingValues = paddingAttr.getValues<APInt>();
2155 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2156 return emitOpError() << "padding tensor must have " << inputRank
2157 << " * 2 = " << inputRank * 2 << " elements, but got "
2158 << paddingValues.size();
2159
2160 auto inputShape = inputType.getShape();
2161 auto outputShape = outputType.getShape();
2162
2163 for (int64_t i = 0; i < inputRank; ++i) {
2164 int64_t padStart = paddingValues[i * 2].getSExtValue();
2165 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2166
2167 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2168 return emitOpError()
2169 << "invalid padding values at dimension " << i
2170 << ": values must be non-negative or -1 for dynamic padding, got ["
2171 << padStart << ", " << padEnd << "]";
2172 }
2173
2174 // Skip shape verification for dynamic input/output
2175 if (inputShape[i] == ShapedType::kDynamic ||
2176 outputShape[i] == ShapedType::kDynamic)
2177 continue;
2178
2179 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2180 return emitOpError() << "mismatch in output shape at dimension " << i
2181 << ": expected " << inputShape[i] << " + "
2182 << padStart << " + " << padEnd << " = "
2183 << (inputShape[i] + padStart + padEnd)
2184 << ", but got " << outputShape[i];
2185 }
2186 }
2187
2188 return success();
2189}
2190
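// For example, slicing an input of shape [10, 20] with start [2, 4] and size
// [3, -1] infers an output of [3, 16]; non-constant start or size values yield
// an all-dynamic shape of matching rank.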
2191LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2192 MLIRContext *context, ::std::optional<Location> location,
2193 SliceOp::Adaptor adaptor,
2194 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2195
2196 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2197 SmallVector<int64_t> start;
2198 SmallVector<int64_t> size;
2199
2200 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2201 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2202 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2203 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2204 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2205 return success();
2206 }
2207
2208 // if size[i] is -1, all remaining elements in dimension i are included
2209 // in the slice, similar to TF.
2210 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2211 // initialize outputShape to all unknown
2212 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2213 if (inputShape.hasRank()) {
2214 for (size_t i = 0; i < size.size(); i++) {
2215 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2216 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2217 start[i] < inputShape.getDimSize(i))) {
2218 // size[i] is not 0 and not < -1, and start[i] is in valid range
2219 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2220 // input shape has unknown dim[i] - only valid if size[i] > 0
2221 if (size[i] > 0) {
2222 outputShape[i] = size[i];
2223 }
2224 } else {
2225 // input shape has known dim[i]
2226 if (size[i] == -1) {
2227 outputShape[i] = inputShape.getDimSize(i) - start[i];
2228 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2229 // start[i] + size[i] is within bound of input shape's dim[i]
2230 outputShape[i] = size[i];
2231 }
2232 }
2233 }
2234 }
2235 } else {
2236 outputShape = convertToMlirShape(size);
2237 }
2238 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2239 return success();
2240}
2241
2242LogicalResult tosa::SliceOp::verify() {
2243 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2244 /* outType = */ getOutput().getType())
2245 .failed())
2246 return failure();
2247
2248 const ShapeAdaptor inputShape(getInput1().getType());
2249 if (inputShape.hasRank()) {
2250 const auto inputRank = inputShape.getRank();
2251 const ShapeAdaptor outputShape(getOutput().getType());
2252 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2253 return emitOpError(
2254 "expect input1 and output to have the same ranks, got ")
2255 << inputRank << " and " << outputShape.getRank();
2256
2257 const auto startShapeRank =
2258 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2259 if (inputRank != startShapeRank)
2260 return emitOpError("length of start is not equal to rank of input shape");
2261
2262 const auto sizeShapeRank =
2263 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2264 if (inputRank != sizeShapeRank)
2265 return emitOpError("length of size is not equal to rank of input shape");
2266 }
2267
2268 return success();
2269}
2270
2271LogicalResult tosa::MulOp::inferReturnTypeComponents(
2272 MLIRContext *context, ::std::optional<Location> location,
2273 ValueShapeRange operands, DictionaryAttr attributes,
2274 OpaqueProperties properties, RegionRange regions,
2275 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2276 // mul op's output shape only depend on input1 and input2, not on shift
2277 ValueShapeRange twoInputs = operands.drop_back();
2278 llvm::SmallVector<int64_t> outShape;
2279 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2280 inferredReturnShapes.push_back(ShapedTypeComponents());
2281 } else {
2282 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2283 }
2284 return success();
2285}
2286
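// For integer operands the element types must match and the result width must
// be at least the operand width (e.g. i8 x i8 -> i32); for floating-point
// operands all element types must match and a constant shift must be 0.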
2287LogicalResult tosa::MulOp::verify() {
2288 const Value output = getOutput();
2289 auto resElemType = getElementTypeOrSelf(output);
2290
2291 // Verify that the element types of the operands and the result match the
2292 // TOSA specification.
2293 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2294 IntegerType lhsIntType =
2295 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2296 IntegerType rhsIntType =
2297 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2298 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2299 return emitOpError("requires the same element type for all operands");
2300
2301 // Though the spec requires the result element type to be i32, a more
2302 // relaxed constraint is allowed at the dialect level to ease interoperation
2303 // with other dialects.
2304 if (lhsIntType.getWidth() > resIntType.getWidth())
2305 return emitOpError("invalid data type size for operands or result");
2306
2307 } else {
2308 // For other supported types, the spec requires the same element type for
2309 // all operands (excluding the `shift` operand) and results.
2310 for (int i = 0; i < 2; ++i) {
2311 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2312 return emitOpError(
2313 "requires the same element type for all operands and results");
2314 }
2315
2316 // verify shift has value 0 for non-integer types
2317 ElementsAttr shiftElem;
2318 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2319 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2320 if (shift != 0) {
2321 return emitOpError() << "require shift to be 0 for float type";
2322 }
2323 }
2324 }
2325
2326 // Verify the op has the same rank for all main operands (excluding extra
2327 // operands such as the shift of mul op, which is the only difference from
2328 // the built-in `SameOperandsAndResultRank` trait) and result types, if known.
2329 TypeRange operandTypes = getOperandTypes();
2330 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2331 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2332
2333 const bool aHasRank = aType.hasRank();
2334 const bool bHasRank = bType.hasRank();
2335 if (aHasRank && bHasRank) {
2336 const int64_t aRank = aType.getRank();
2337 const int64_t bRank = bType.getRank();
2338 if (aRank != bRank)
2339 return emitOpError("a and b operands don't have matching ranks, got ")
2340 << aRank << " and " << bRank;
2341
2342 // check for broadcast compatible shapes
2343 SmallVector<int64_t> resultShape;
2344 if (!mlir::OpTrait::util::getBroadcastedShape(
2345 aType.getShape(), bType.getShape(), resultShape))
2346 return emitOpError("a and b operands don't have broadcast-compatible "
2347 "shapes, got ")
2348 << aType << " and " << bType;
2349 }
2350
2351 ShapedType resultType = cast<ShapedType>(output.getType());
2352 if (!resultType.hasRank())
2353 return success();
2354
2355 const int64_t resultRank = resultType.getRank();
2356 if (aHasRank && resultRank != aType.getRank())
2357 return emitOpError("result type has different rank than a, got ")
2358 << resultRank << " vs " << aType.getRank();
2359 if (bHasRank && resultRank != bType.getRank())
2360 return emitOpError("result type has different rank than b, got ")
2361 << resultRank << " vs " << bType.getRank();
2362
2363 return success();
2364}
2365
2366LogicalResult tosa::TableOp::inferReturnTypeComponents(
2367 MLIRContext *context, ::std::optional<Location> location,
2368 TableOp::Adaptor adaptor,
2369 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2370 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2371
2372 if (!inputShape.hasRank()) {
2373 inferredReturnShapes.push_back(ShapedTypeComponents());
2374 return success();
2375 }
2376
2377 inferredReturnShapes.resize(1);
2378 inputShape.getDims(inferredReturnShapes[0]);
2379 return success();
2380}
2381
2382LogicalResult tosa::TableOp::verify() {
2383 const TensorType inputType = getInput1().getType();
2384 const TensorType outputType = getOutput().getType();
2385
2386 if (!inputType.hasRank() || !outputType.hasRank())
2387 return success();
2388
2389 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2390 "result")))
2391 return failure();
2392
2393 auto inputDims = inputType.getShape();
2394 auto outputDims = outputType.getShape();
2395 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2396 int64_t dim = it.index();
2397 auto [inputDim, outputDim] = it.value();
2398 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2399 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2400 << " doesn't match dim(input, " << dim
2401 << ") = " << inputDim;
2402 }
2403 }
2404 return success();
2405}
2406
2407LogicalResult
2408tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2409 // Multiples must be constants.
2410 DenseIntElementsAttr multiplesAttr;
2411 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2412 return failure();
2413 multiples = llvm::to_vector(
2414 llvm::map_range(multiplesAttr.getValues<APInt>(),
2415 [](const APInt &val) { return val.getSExtValue(); }));
2416 return success();
2417}
2418
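// For example, tiling an input of shape [3, ?] with multiples [2, 4] infers an
// output of [6, ?]; non-constant multiples yield an all-dynamic shape whose
// rank matches the multiples operand.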
2419LogicalResult tosa::TileOp::inferReturnTypeComponents(
2420 MLIRContext *context, ::std::optional<Location> location,
2421 TileOp::Adaptor adaptor,
2422 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2423 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2424 SmallVector<int64_t> multiples;
2425 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2426 multiples)) {
2427 auto rank =
2428 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2429 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2430 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2431 return success();
2432 } else {
2433 multiples = convertToMlirShape(multiples);
2434 }
2435
2436 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2437 SmallVector<int64_t> outputShape;
2438 if (!inputShape.hasRank()) {
2439 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2440 inferredReturnShapes.push_back(
2441 ShapedTypeComponents(outputShape, inputType));
2442 return success();
2443 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2444 return failure();
2445
2446 // Any non-dynamic dimension can be multiplied to a known size.
2447 outputShape.reserve(multiples.size());
2448 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2449 if (multiples[i] == ShapedType::kDynamic) {
2450 outputShape.push_back(ShapedType::kDynamic);
2451 } else {
2452 int64_t dim = inputShape.getDimSize(i);
2453 if (dim != ShapedType::kDynamic)
2454 dim *= multiples[i];
2455 outputShape.push_back(dim);
2456 }
2457 }
2458
2459 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2460 return success();
2461}
2462
2463LogicalResult tosa::TileOp::verify() {
2464 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2465 /* outType = */ getOutput().getType())
2466 .failed()) {
2467 return failure();
2468 }
2469 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2470 ShapedType outputType = llvm::cast<ShapedType>(getType());
2471
2472 shapeType multiplesType =
2473 llvm::cast<tosa::shapeType>(getMultiples().getType());
2474
2475 auto multiplesRank = multiplesType.getRank();
2476
2477 if (inputType.hasRank()) {
2478 if (inputType.getRank() != multiplesRank)
2479 return emitOpError("expect 'multiples' to have rank ")
2480 << inputType.getRank() << " but got " << multiplesRank << ".";
2481 if (outputType.hasRank() &&
2482 failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2483 "output")))
2484 return failure();
2485 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2486 return emitOpError("expect 'multiples' array to have length ")
2487 << outputType.getRank() << " but got " << multiplesRank << ".";
2488
2489 SmallVector<int64_t> multiples;
2490 if (getConstantMultiples(multiples).succeeded() &&
2491 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2492 return emitOpError(
2493 "expect element of 'multiples' to be positive integer or -1.");
2494
2495 return success();
2496}
2497
2498bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2499 if (l.size() != r.size() || l.size() != 1)
2500 return false;
2501 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2502}
2503
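// For example, reshaping a static [2, 12] input (24 elements) with new shape
// [4, -1] infers the dynamic dimension as 24 / 4 = 6, giving [4, 6].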
2504LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2505 MLIRContext *context, ::std::optional<Location> location,
2506 ReshapeOp::Adaptor adaptor,
2507 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2508 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2509 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2510 llvm::SmallVector<int64_t> newShapeValue;
2511 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2512 newShapeValue)) {
2513 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2514 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2515 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2516 return success();
2517 } else {
2518 newShapeValue = convertToMlirShape(newShapeValue);
2519 }
2520
2521 // We cannot infer from the total number of elements so we must take the
2522 // shape attribute as exact.
2523 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2524 inferredReturnShapes.push_back(
2525 ShapedTypeComponents(newShapeValue, inputType));
2526 return success();
2527 }
2528
2529 // Determine the number of elements covered by the slice of all static
2530 // dimensions. This allows us to infer the length of the remaining dynamic
2531 // dimension.
2532 int64_t numElements = inputShape.getNumElements();
2533 int64_t staticMul = 1;
2534 for (auto val : newShapeValue) {
2535 if (ShapedType::isStatic(val)) {
2536 staticMul *= val;
2537 }
2538 }
2539
2540 // Determine the length of the dynamic dimension.
2541 for (auto &val : newShapeValue) {
2542 if (ShapedType::isDynamic(val))
2543 val = numElements / staticMul;
2544 }
2545
2546 inferredReturnShapes.push_back(
2547 ShapedTypeComponents(newShapeValue, inputType));
2548 return success();
2549}
2550
2551llvm::LogicalResult tosa::ReshapeOp::verify() {
2552 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2553 /* outType = */ getOutput().getType())
2554 .failed()) {
2555 return failure();
2556 }
2557 TensorType inputType = getInput1().getType();
2558
2559 SmallVector<int64_t> shapeValues;
2560 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2561 // skip following checks if shape is not constant
2562 return mlir::success();
2563 }
2564
2565 int missingDims = llvm::count(shapeValues, -1);
2566 if (missingDims > 1)
2567 return emitOpError() << "expected at most one target dimension to be -1";
2568
2569 const auto outputType = dyn_cast<RankedTensorType>(getType());
2570 if (!outputType)
2571 return success();
2572
2573 if ((int64_t)shapeValues.size() != outputType.getRank())
2574 return emitOpError() << "new shape does not match result rank";
2575
2576 for (auto [newShapeDim, outputShapeDim] :
2577 zip(shapeValues, outputType.getShape())) {
2578 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2579 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2580 return emitOpError() << "new shape is inconsistent with result shape";
2581
2582 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2583 return emitOpError() << "new shape has invalid tensor dimension size "
2584 << newShapeDim;
2585 }
2586
2587 if (inputType.hasStaticShape()) {
2588 int64_t inputElementsNum = inputType.getNumElements();
2589 if (outputType.hasStaticShape()) {
2590 int64_t outputElementsNum = outputType.getNumElements();
2591 if (inputElementsNum != outputElementsNum) {
2592 return emitOpError() << "cannot reshape " << inputElementsNum
2593 << " elements into " << outputElementsNum;
2594 }
2595 }
2596
2597 int64_t newShapeElementsNum =
2598 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2599 return (dim > 0) ? acc * dim : acc;
2600 });
2601 bool isStaticNewShape =
2602 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2603 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2604 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2605 return emitOpError() << "cannot reshape " << inputElementsNum
2606 << " elements into " << newShapeElementsNum;
2607 }
2608 }
2609
2610 return mlir::success();
2611}
2612
2613// Return failure if val is not a constant.
2614// Set zp to -1 if val is a non-zero float or is neither an integer nor a float.
2615// Otherwise set zp to val's constant value.
2616static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2617 ElementsAttr zpAttr;
2618 if (!matchPattern(val, m_Constant(&zpAttr))) {
2619 return failure();
2620 }
2621
2622 Type zpElemType = zpAttr.getElementType();
2623
2624 if (llvm::isa<FloatType>(zpElemType)) {
2625 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2626 return 0;
2627 }
2628 // return non-zero value to trigger error check
2629 return -1;
2630 }
2631
2632 if (llvm::isa<IntegerType>(zpElemType)) {
2633 if (signExtend)
2634 return zpAttr.getValues<APInt>()[0].getSExtValue();
2635 else
2636 return zpAttr.getValues<APInt>()[0].getZExtValue();
2637 }
2638
2639 // return non-zero value to trigger error check
2640 return -1;
2641}
2642
2643template <typename T>
2644static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2645 const std::string &operand) {
2646 Type zpElemType = getElementTypeOrSelf(val);
2647
2648 if (!zpElemType.isInteger(8) && zp != 0) {
2649 // convert operand to lower case for error message
2650 std::string lower = operand;
2651 llvm::transform(lower, lower.begin(), ::tolower);
2652 return op.emitOpError()
2653 << lower << " zero point must be zero for non-int8 integer types";
2654 }
2655
2656 return success();
2657}
2658
2659static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2660 const int64_t &zp,
2661 const std::string &operand) {
2662 bool isInputZp = (operand == "Input");
2663
2664 bool tensorUnsigned =
2665 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2666 StringRef tensorName = isInputZp ? "input" : "output";
2667
2668 Type zpElemType = getElementTypeOrSelf(zpVal);
2669
2670 if (zp != 0) {
2671 if (!zpElemType.isInteger(8) &&
2672 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2673 return op.emitOpError()
2674 << "expect " << tensorName << "_zp of 0, got " << zp;
2675 }
2676 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2677 return op.emitOpError() << "expect " << tensorName
2678 << "_zp of 0 or 32768 for unsigned int16 "
2679 << tensorName << ", got " << zp;
2680 }
2681 }
2682
2683 return success();
2684}
2685
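// ZERO_POINT_HELPER(OP, NAME, SIGN_EXTEND) defines both the zero-point getter
// and verifier for an operand; for example ZERO_POINT_HELPER(MatMulOp, A, true)
// expands to tosa::MatMulOp::getAZeroPoint() and
// tosa::MatMulOp::verifyAZeroPoint(int64_t).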
2686#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2687 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2688 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2689 } \
2690 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2691 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2692 }
2693
2694ZERO_POINT_HELPER(Conv2DOp, Input, true)
2695ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2696ZERO_POINT_HELPER(Conv3DOp, Input, true)
2697ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2698ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2699ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2700ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2701ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2702ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2703ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2704ZERO_POINT_HELPER(MatMulOp, A, true)
2705ZERO_POINT_HELPER(MatMulOp, B, true)
2706ZERO_POINT_HELPER(NegateOp, Input1, true)
2707ZERO_POINT_HELPER(NegateOp, Output, true)
2708ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2709ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2710#undef ZERO_POINT_HELPER
2711
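// For example, an input of shape [1, 2, 3] with perms [2, 0, 1] infers an
// output of [3, 1, 2]; if all input dimensions are equal, the permutation does
// not affect the inferred shape.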
2712LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2713 MLIRContext *context, ::std::optional<Location> location,
2714 TransposeOp::Adaptor adaptor,
2715 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2716 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2717
2718 // If the input rank is unknown, the output rank is also unknown; the
2719 // permutation is not used on its own to infer a rank here.
2720 if (!inputShape.hasRank()) {
2721 inferredReturnShapes.push_back(ShapedTypeComponents());
2722 return success();
2723 }
2724
2725 const auto inputRank = inputShape.getRank();
2726
2727 // This would imply the number of permutations does not match the rank of
2728 // the input which is illegal.
2729 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2730 return failure();
2731 }
2732
2733 SmallVector<int64_t> outputShape;
2734 // Rank-0 means no permutations matter.
2735 if (inputRank == 0) {
2736 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2737 return success();
2738 }
2739
2740 // Check whether the input dimensions are all the same.
2741 bool allTheSame = true;
2742 for (int i = 1, s = inputRank; i < s; i++) {
2743 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2744 allTheSame = false;
2745 break;
2746 }
2747 }
2748
2749 // If all of the input dimensions are the same we don't care about the
2750 // permutation.
2751 if (allTheSame) {
2752 outputShape.resize(inputRank, inputShape.getDimSize(0));
2753 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2754 return success();
2755 }
2756
2757 outputShape.resize(inputRank, ShapedType::kDynamic);
2758
2759 // Constant permutation values must be within the input rank.
2760 if (llvm::any_of(adaptor.getPerms(),
2761 [inputRank](const auto i) { return i >= inputRank; }))
2762 return failure();
2763
2764 outputShape.reserve(inputRank);
2765 for (int i = 0, s = inputRank; i < s; i++) {
2766 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2767 }
2768
2769 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2770 return success();
2771}
2772
2773LogicalResult tosa::TransposeOp::verify() {
2774 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2775 /* outType = */ getOutput().getType())
2776 .failed()) {
2777 return failure();
2778 }
2779
2780 const ShapeAdaptor inputShape(getInput1().getType());
2781 const ShapeAdaptor outputShape(getOutput().getType());
2782
2783 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2784
2785 if (inputShape.hasRank() &&
2786 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2787 return emitOpError() << "expected perms attribute to have size "
2788 << inputShape.getRank()
2789 << " (input rank) but got size "
2790 << constantPerms.size();
2791
2792 if (inputShape.hasRank() && outputShape.hasRank() &&
2793 inputShape.getRank() != outputShape.getRank())
2794 return emitOpError()
2795 << "expected input tensor rank to equal result tensor rank";
2796
2797 if (outputShape.hasRank() &&
2798 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2799 return emitOpError() << "expected perms attribute to have size "
2800 << outputShape.getRank()
2801 << " (output rank) but got size "
2802 << constantPerms.size();
2803
2804 if (!llvm::all_of(constantPerms,
2805 [&constantPerms](int32_t s) {
2806 return s >= 0 &&
2807 static_cast<size_t>(s) < constantPerms.size();
2808 }) ||
2809 !isPermutationVector(llvm::to_vector(llvm::map_range(
2810 constantPerms, [](int32_t v) -> int64_t { return v; }))))
2811 return emitOpError() << "expected valid permutation indices";
2812
2813 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2814 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2815 inputShape.getNumElements() != outputShape.getNumElements())
2816 return emitOpError() << "expected input1 and output to have same numbers "
2817 "of elements, got "
2818 << inputShape.getNumElements() << " and "
2819 << outputShape.getNumElements();
2820
2821 // Verify that the types of the input and output tensors are properly
2822 // permuted.
2823 if (inputShape.hasRank() && outputShape.hasRank()) {
2824 for (auto i = 0; i < outputShape.getRank(); i++) {
2825 if (inputShape.isDynamicDim(constantPerms[i]) ||
2826 outputShape.isDynamicDim(i))
2827 continue;
2828
2829 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2830 return emitOpError()
2831 << "expected output tensor dim " << i << " to match "
2832 << "input dim " << constantPerms[i] << " with value of "
2833 << inputShape.getDimSize(constantPerms[i]);
2834 }
2835 }
2836
2837 return success();
2838}
2839
2840LogicalResult TransposeOp::reifyResultShapes(
2841 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2842
2843 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2844
2845 Value input = getInput1();
2846 auto inputType = cast<TensorType>(input.getType());
2847
2848 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2849 for (auto dim : transposePerms) {
2850 int32_t dimInInput = transposePerms[dim];
2851 if (inputType.isDynamicDim(dimInInput))
2852 returnedDims[dim] =
2853 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2854 .getResult();
2855 else
2856 returnedDims[dim] =
2857 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2858 }
2859
2860 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2861 return success();
2862}
2863
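// The output shape is [N, W, C], taken from values [N, K, C] and indices
// [N, W]; for example values of shape [2, 5, 8] and indices of shape [2, 3]
// infer an output of [2, 3, 8].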
2864LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2865 MLIRContext *context, ::std::optional<Location> location,
2866 GatherOp::Adaptor adaptor,
2867 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2868 llvm::SmallVector<int64_t> outputShape;
2869 outputShape.resize(3, ShapedType::kDynamic);
2870
2871 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2872 if (valuesShape.hasRank()) {
2873 outputShape[0] = valuesShape.getDimSize(0);
2874 outputShape[2] = valuesShape.getDimSize(2);
2875 }
2876
2877 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2878 if (indicesShape.hasRank()) {
2879 if (outputShape[0] == ShapedType::kDynamic)
2880 outputShape[0] = indicesShape.getDimSize(0);
2881 if (outputShape[1] == ShapedType::kDynamic)
2882 outputShape[1] = indicesShape.getDimSize(1);
2883 }
2884
2885 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2886 return success();
2887}
2888
2889LogicalResult tosa::GatherOp::verify() {
2890 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2891 /* outType = */ getOutput().getType())
2892 .failed()) {
2893 return failure();
2894 }
2895
2896 const ShapeAdaptor valuesShape(getValues().getType());
2897 const ShapeAdaptor indicesShape(getIndices().getType());
2898 const ShapeAdaptor outputShape(getOutput().getType());
2899
2900 int64_t n = ShapedType::kDynamic;
2901 int64_t w = ShapedType::kDynamic;
2902 int64_t c = ShapedType::kDynamic;
2903
2904 if (valuesShape.hasRank()) {
2905 n = valuesShape.getDimSize(0);
2906 c = valuesShape.getDimSize(2);
2907 }
2908 if (indicesShape.hasRank()) {
2909 const int64_t indicesN = indicesShape.getDimSize(0);
2910 w = indicesShape.getDimSize(1);
2911 if (n == ShapedType::kDynamic)
2912 n = indicesN;
2913 else if (indicesN != ShapedType::kDynamic && n != indicesN)
2914 return emitOpError() << "requires indices dimension 0 to have size " << n
2915 << ", got " << indicesN;
2916 }
2917 if (outputShape.hasRank()) {
2918 const int64_t outputN = outputShape.getDimSize(0);
2919 const int64_t outputW = outputShape.getDimSize(1);
2920 const int64_t outputC = outputShape.getDimSize(2);
2921 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2922 n != outputN)
2923 return emitOpError() << "requires output dimension 0 to have size " << n
2924 << ", got " << outputN;
2925
2926 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2927 w != outputW)
2928 return emitOpError() << "requires output dimension 1 to have size " << w
2929 << ", got " << outputW;
2930 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2931 c != outputC)
2932 return emitOpError() << "requires output dimension 2 to have size " << c
2933 << ", got " << outputC;
2934 }
2935 return success();
2936}
2937
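// For example, a 1x8x8x1 input with scale [2, 1, 2, 1], offset [0, 0] and
// border [0, 0] infers a 1x15x15x1 output, since ((8 - 1) * 2 - 0 + 0) / 1 + 1
// = 15 in each spatial dimension.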
2938LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2939 MLIRContext *context, ::std::optional<Location> location,
2940 ResizeOp::Adaptor adaptor,
2941 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2942 llvm::SmallVector<int64_t, 4> outputShape;
2943 outputShape.resize(4, ShapedType::kDynamic);
2944
2945 ShapeAdaptor inputShape(adaptor.getInput().getType());
2946 if (!inputShape.hasRank())
2947 return failure();
2948
2949 outputShape[0] = inputShape.getDimSize(0);
2950 outputShape[3] = inputShape.getDimSize(3);
2951 int64_t inputHeight = inputShape.getDimSize(1);
2952 int64_t inputWidth = inputShape.getDimSize(2);
2953
2954 if ((inputHeight == ShapedType::kDynamic) ||
2955 (inputWidth == ShapedType::kDynamic))
2956 return failure();
2957
2958 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2959 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2960 scaleInt) ||
2961 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2962 offsetInt) ||
2963 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2964 borderInt)) {
2965 return failure();
2966 }
2967
2968 // Compute the output shape based on attributes: scale, offset, and border.
2969 const int64_t outputHeight =
2970 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2971 scaleInt[1]) +
2972 1;
2973
2974 const int64_t outputWidth =
2975 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2976 scaleInt[3]) +
2977 1;
2978
2979 if (outputHeight < 0 || outputWidth < 0) {
2980 return emitOptionalError(
2981 location,
2982 "calculated output height and width must be non-negative, "
2983 "got height = ",
2984 outputHeight, ", width = ", outputWidth);
2985 }
2986
2987 outputShape[1] = outputHeight;
2988 outputShape[2] = outputWidth;
2989 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2990 return success();
2991}
2992
2993LogicalResult tosa::ResizeOp::verify() {
2994 const Value input = getInput();
2995 const Value output = getOutput();
2996 const RankedTensorType inputType =
2997 llvm::dyn_cast<RankedTensorType>(input.getType());
2998 const RankedTensorType outputType =
2999 llvm::dyn_cast<RankedTensorType>(output.getType());
3000
3001 SmallVector<int64_t> scaleValues;
3002 SmallVector<int64_t> offsetValues;
3003 SmallVector<int64_t> borderValues;
3004 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
3005 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
3006 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
3007 // Skip following checks if shape is not constant
3008 return success();
3009 }
3010
3011 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
3012 return emitOpError("expect all scale values to be > 0, got ")
3013 << scaleValues;
3014
3015 const int64_t scaleYN = scaleValues[0];
3016 const int64_t scaleYD = scaleValues[1];
3017 const int64_t scaleXN = scaleValues[2];
3018 const int64_t scaleXD = scaleValues[3];
3019
3020 const int64_t offsetY = offsetValues[0];
3021 const int64_t offsetX = offsetValues[1];
3022
3023 const int64_t borderY = borderValues[0];
3024 const int64_t borderX = borderValues[1];
3025
3026 if (!inputType)
3027 return success();
3028 if (!outputType)
3029 return success();
3030
3031 const int64_t oh = outputType.getDimSize(1);
3032 const int64_t ow = outputType.getDimSize(2);
3033 const int64_t ih = inputType.getDimSize(1);
3034 const int64_t iw = inputType.getDimSize(2);
3035
3036 // Don't check with input height that could be broadcast (ih != 1)
3037 // since Linalg, a consumer of TOSA, expects broadcasting support
3038 // in resize to be available. Taking the cautious approach for now,
3039 // we can consider removing support for broadcasting later.
3040 if (ih != ShapedType::kDynamic && ih != 1) {
3041 const std::optional<int64_t> calculatedOutHeightMinusOne =
3042 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3043 if (!calculatedOutHeightMinusOne.has_value())
3044 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3045 "border_y ")
3046 << "to be wholly divisible by scale_y_d, got ((" << ih
3047 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3048 << ") / " << scaleYD;
3049 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3050 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3051 return emitOpError("calculated output height did not match expected: ")
3052 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3053 }
3054
3055 // Skip the check when the input width could be broadcast (iw == 1),
3056 // since Linalg, a consumer of TOSA, expects broadcasting support
3057 // in resize to be available. Taking the cautious approach for now,
3058 // we can consider removing support for broadcasting later.
3059 if (iw != ShapedType::kDynamic && iw != 1) {
3060 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3061 const std::optional<int64_t> calculatedOutWidthMinusOne =
3062 idivCheck(scaledInWidth, scaleXD);
3063 if (!calculatedOutWidthMinusOne.has_value())
3064 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3065 "border_x ")
3066 << "to be wholly divisible by scale_x_d, got ((" << iw
3067 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3068 << ") / " << scaleXD;
3069 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3070 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3071 return emitOpError("calculated output width did not match expected: ")
3072 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3073 }
3074
3075 return success();
3076}
3077
3078LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3079 MLIRContext *context, ::std::optional<Location> location,
3080 ScatterOp::Adaptor adaptor,
3081 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3082 llvm::SmallVector<int64_t> outputShape;
3083 outputShape.resize(3, ShapedType::kDynamic);
3084
3085 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3086 if (valuesInShape.hasRank()) {
3087 outputShape[0] = valuesInShape.getDimSize(0);
3088 outputShape[1] = valuesInShape.getDimSize(1);
3089 outputShape[2] = valuesInShape.getDimSize(2);
3090 }
3091
3092 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3093 if (indicesShape.hasRank()) {
3094 if (outputShape[0] == ShapedType::kDynamic)
3095 outputShape[0] = indicesShape.getDimSize(0);
3096 }
3097
3098 ShapeAdaptor inputShape(adaptor.getInput().getType());
3099 if (inputShape.hasRank()) {
3100 if (outputShape[0] == ShapedType::kDynamic)
3101 outputShape[0] = inputShape.getDimSize(0);
3102 if (outputShape[2] == ShapedType::kDynamic)
3103 outputShape[2] = inputShape.getDimSize(2);
3104 }
3105
3106 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3107 return success();
3108}
3109
3110LogicalResult tosa::ScatterOp::verify() {
3111 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3112 /* outType = */ getValuesOut().getType())
3113 .failed() ||
3114 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3115 /* outType = */ getValuesOut().getType())
3116 .failed()) {
3117 return failure();
3118 }
3119
3120 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3121 const ShapeAdaptor indicesShape(getIndices().getType());
3122 const ShapeAdaptor inputShape(getInput().getType());
3123 const ShapeAdaptor outputShape(getValuesOut().getType());
3124
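  // Per the TOSA specification, values_in is [N, K, C], indices is [N, W],
  // input is [N, W, C] and values_out is [N, K, C]; the checks below
  // reconcile N, K, W and C across the operands.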
3125 int64_t n = ShapedType::kDynamic;
3126 int64_t k = ShapedType::kDynamic;
3127 int64_t w = ShapedType::kDynamic;
3128 int64_t c = ShapedType::kDynamic;
3129 if (valuesInShape.hasRank()) {
3130 n = valuesInShape.getDimSize(0);
3131 k = valuesInShape.getDimSize(1);
3132 c = valuesInShape.getDimSize(2);
3133 }
3134 if (indicesShape.hasRank()) {
3135 const int64_t indicesN = indicesShape.getDimSize(0);
3136 w = indicesShape.getDimSize(1);
3137 if (n == ShapedType::kDynamic)
3138 n = indicesN;
3139 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3140 return emitOpError() << "requires indices dimension 0 to have size " << n
3141 << ", got " << indicesN;
3142 }
3143 if (inputShape.hasRank()) {
3144 const int64_t inputN = inputShape.getDimSize(0);
3145 const int64_t inputW = inputShape.getDimSize(1);
3146 const int64_t inputC = inputShape.getDimSize(2);
3147 if (n == ShapedType::kDynamic)
3148 n = inputN;
3149 else if (inputN != ShapedType::kDynamic && n != inputN)
3150 return emitOpError() << "requires input dimension 0 to have size " << n
3151 << ", got " << inputN;
3152 if (w == ShapedType::kDynamic)
3153 w = inputW;
3154 else if (inputW != ShapedType::kDynamic && w != inputW)
3155 return emitOpError() << "requires input dimension 1 to have size " << w
3156 << ", got " << inputW;
3157
3158 if (c == ShapedType::kDynamic)
3159 c = inputC;
3160 else if (inputC != ShapedType::kDynamic && c != inputC)
3161 return emitOpError() << "requires input dimension 2 to have size " << c
3162 << ", got " << inputC;
3163 }
3164 if (outputShape.hasRank()) {
3165 const int64_t outputN = outputShape.getDimSize(0);
3166 const int64_t outputK = outputShape.getDimSize(1);
3167 const int64_t outputC = outputShape.getDimSize(2);
3168 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3169 n != outputN)
3170 return emitOpError() << "requires values_out dimension 0 to have size "
3171 << n << ", got " << outputN;
3172 if (k == ShapedType::kDynamic)
3173 k = outputK;
3174 else if (outputK != ShapedType::kDynamic && k != outputK)
3175 return emitOpError() << "requires values_out dimension 1 to have size "
3176 << k << ", got " << outputK;
3177 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3178 c != outputC)
3179 return emitOpError() << "requires values_out dimension 2 to have size "
3180 << c << ", got " << outputC;
3181 }
3182 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3183 return emitOpError() << "requires dimensions K >= W, got K=" << k
3184 << " and W=" << w;
3185
3186 return success();
3187}
3188
3189static LogicalResult ReduceInferReturnTypes(
3190 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3191 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3192 int64_t axisVal = axis.getValue().getSExtValue();
3193 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3194 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3195 return success();
3196 }
3197
3198 SmallVector<int64_t> outputShape;
3199 operandShape.getDims(outputShape);
3200 outputShape[axisVal] = 1;
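  // For example, reducing a [2, 3, 4] input along axis 1 infers a [2, 1, 4]
  // output shape.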
3201 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3202 return success();
3203}
3204
3205#define COMPATIBLE_RETURN_TYPES(OP) \
3206 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3207 if (l.size() != r.size() || l.size() != 1) \
3208 return false; \
3209 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3210 return false; \
3211 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3212 }
3213
3214#define REDUCE_SHAPE_INFER(OP) \
3215 LogicalResult OP::inferReturnTypeComponents( \
3216 MLIRContext *context, ::std::optional<Location> location, \
3217 OP::Adaptor adaptor, \
3218 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3219 Type inputType = \
3220 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3221 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3222 const Properties &prop = adaptor.getProperties(); \
3223 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3224 inferredReturnShapes); \
3225 } \
3226 COMPATIBLE_RETURN_TYPES(OP)
3227
3228REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3229REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3230REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3231REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3232REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3233REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3234#undef REDUCE_SHAPE_INFER
3235COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3236#undef COMPATIBLE_RETURN_TYPES
3237
3238template <typename T>
3239static LogicalResult verifyReduceOp(T op) {
3240 // All TOSA reduce Ops have input, output and axis.
3241 TensorType inputType = op.getInput().getType();
3242 TensorType outputType = op.getOutput().getType();
3243 int32_t reduceAxis = op.getAxis();
3244
3245 if (reduceAxis < 0) {
3246 op.emitOpError("reduce axis must not be negative");
3247 return failure();
3248 }
3249 if (inputType.hasRank()) {
3250 int64_t inputRank = inputType.getRank();
3251 // We allow for a special case where the input/output shape has rank 0 and
3252 // axis is also 0.
3253 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3254 op.emitOpError("expect input tensor rank (")
3255 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3256 << ")";
3257 return failure();
3258 }
3259 }
3260 if (outputType.hasRank()) {
3261 int64_t outputRank = outputType.getRank();
3262 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3263 op.emitOpError(
3264 "expect output tensor rank to be equal to input tensor rank");
3265 return failure();
3266 }
3267 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3268 op.emitOpError("expect output tensor rank (")
3269 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3270 << ")";
3271 return failure();
3272 }
3273 // We can only verify the reduced dimension size to be 1 if this is not
3274 // the special case of output rank == 0.
3275 if (outputRank != 0) {
3276 auto outputShape = outputType.getShape();
3277 if (!outputType.isDynamicDim(reduceAxis) &&
3278 outputShape[reduceAxis] != 1) {
3279 op.emitOpError("expect reduced dimension size to be 1, got ")
3280 << outputShape[reduceAxis];
3281 return failure();
3282 }
3283 }
3284 }
3285 return success();
3286}
3287
3288LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3289LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3290LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3291LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3292LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3293LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3294
3295static LogicalResult NAryInferReturnTypes(
3296 const ValueShapeRange &operands,
3297 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3298 llvm::SmallVector<int64_t> outShape;
3299 if (resolveBroadcastShape(operands, outShape).failed()) {
3300 inferredReturnShapes.push_back(ShapedTypeComponents());
3301 } else {
3302 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3303 }
3304 return success();
3305}
3306
3307#define NARY_SHAPE_INFER(OP) \
3308 LogicalResult OP::inferReturnTypeComponents( \
3309 MLIRContext *context, ::std::optional<Location> location, \
3310 ValueShapeRange operands, DictionaryAttr attributes, \
3311 OpaqueProperties properties, RegionRange regions, \
3312 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3313 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3314 }
3315
3316NARY_SHAPE_INFER(tosa::AbsOp)
3317NARY_SHAPE_INFER(tosa::AddOp)
3318NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3319NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3320NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3321NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3322NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3323NARY_SHAPE_INFER(tosa::CastOp)
3324NARY_SHAPE_INFER(tosa::CeilOp)
3325NARY_SHAPE_INFER(tosa::ClampOp)
3326NARY_SHAPE_INFER(tosa::ClzOp)
3327NARY_SHAPE_INFER(tosa::CosOp)
3328NARY_SHAPE_INFER(tosa::ExpOp)
3329NARY_SHAPE_INFER(tosa::FloorOp)
3330NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3331NARY_SHAPE_INFER(tosa::GreaterOp)
3332NARY_SHAPE_INFER(tosa::IdentityOp)
3333NARY_SHAPE_INFER(tosa::IntDivOp)
3334NARY_SHAPE_INFER(tosa::LogOp)
3335NARY_SHAPE_INFER(tosa::LogicalAndOp)
3336NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3337NARY_SHAPE_INFER(tosa::LogicalNotOp)
3338NARY_SHAPE_INFER(tosa::LogicalOrOp)
3339NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3340NARY_SHAPE_INFER(tosa::LogicalXorOp)
3341NARY_SHAPE_INFER(tosa::MaximumOp)
3342NARY_SHAPE_INFER(tosa::MinimumOp)
3343NARY_SHAPE_INFER(tosa::PowOp)
3344NARY_SHAPE_INFER(tosa::ReciprocalOp)
3345NARY_SHAPE_INFER(tosa::ReverseOp)
3346NARY_SHAPE_INFER(tosa::RsqrtOp)
3347NARY_SHAPE_INFER(tosa::SinOp)
3348NARY_SHAPE_INFER(tosa::SelectOp)
3349NARY_SHAPE_INFER(tosa::SubOp)
3350NARY_SHAPE_INFER(tosa::TanhOp)
3351NARY_SHAPE_INFER(tosa::ErfOp)
3352NARY_SHAPE_INFER(tosa::SigmoidOp)
3353 #undef NARY_SHAPE_INFER
3354
3355LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3356 MLIRContext *context, ::std::optional<Location> location,
3357 NegateOp::Adaptor adaptor,
3358 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3359 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3360 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3361 return success();
3362}
3363
3364LogicalResult tosa::NegateOp::verify() {
3365 // Verify same element type
3366 const Type input1Type = getInput1().getType();
3367 const Type outputType = getOutput().getType();
3368 if (verifySameElementTypes(*this, input1Type, outputType).failed())
3369 return failure();
3370
3371 // Verify same shape
3372 const SmallVector<Type, 2> types = {input1Type, outputType};
3373 if (failed(verifyCompatibleShapes(types)))
3374 return emitOpError() << "requires the same shape for input1 and output";
3375
3376 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3377 const Type input1ZpEType =
3378 getStorageElementTypeOrSelf(getInput1Zp().getType());
3379 if (input1EType != input1ZpEType) {
3380 return emitOpError("expect both input1 and its zero point are the same "
3381 "element type, got ")
3382 << input1EType << " and " << input1ZpEType;
3383 }
3384 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3385 const Type outputZpEType =
3386 getStorageElementTypeOrSelf(getOutputZp().getType());
3387 if (outputEType != outputZpEType) {
3388 return emitOpError("expect both output and its zero point are the same "
3389 "element type, got ")
3390 << outputEType << " and " << outputZpEType;
3391 }
3392
3393 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3394 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3395 return failure();
3396
3397 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3398 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3399 return failure();
3400
3401 return success();
3402}
3403
3404static LogicalResult poolingInferReturnTypes(
3405 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3406 ArrayRef<int64_t> pad,
3407 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3408 llvm::SmallVector<int64_t> outputShape;
3409 outputShape.resize(4, ShapedType::kDynamic);
3410
3411 // If the input type is unranked, only the output rank is known.
3412 if (!inputShape.hasRank()) {
3413 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3414 return success();
3415 }
3416
3417 // Batch and number of channels are preserved through a pooling layer.
3418 outputShape[0] = inputShape.getDimSize(0);
3419 outputShape[3] = inputShape.getDimSize(3);
3420
3421 int64_t height = inputShape.getDimSize(1);
3422 int64_t width = inputShape.getDimSize(2);
3423
3424 if (ShapedType::isStatic(height)) {
3425 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3426 outputShape[1] = padded / stride[0] + 1;
3427 }
3428
3429 if (ShapedType::isStatic(width)) {
3430 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3431 outputShape[2] = padded / stride[1] + 1;
3432 }
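  // For illustration (hypothetical values): height = 8, pad = [1, 1, ...],
  // kernel = [3, ...] and stride = [2, ...] give (8 + 1 + 1 - 3) / 2 + 1 = 4
  // output rows.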
3433
3434 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3435 return success();
3436}
3437
3438LogicalResult Conv2DOp::inferReturnTypeComponents(
3439 MLIRContext *context, ::std::optional<Location> location,
3440 Conv2DOp::Adaptor adaptor,
3441 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3442 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3443
3444 int64_t inputWidth = ShapedType::kDynamic;
3445 int64_t inputHeight = ShapedType::kDynamic;
3446 int64_t weightWidth = ShapedType::kDynamic;
3447 int64_t weightHeight = ShapedType::kDynamic;
3448
3449 // Input shape describes input width/height and batch.
3450
3451 ShapeAdaptor inputShape(adaptor.getInput().getType());
3452 if (inputShape.hasRank()) {
3453 outputShape[0] = inputShape.getDimSize(0);
3454 inputHeight = inputShape.getDimSize(1);
3455 inputWidth = inputShape.getDimSize(2);
3456 }
3457
3458 // Weight shape describes the filter width/height and the output channels.
3459 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3460 if (weightShape.hasRank()) {
3461 outputShape[3] = weightShape.getDimSize(0);
3462 weightHeight = weightShape.getDimSize(1);
3463 weightWidth = weightShape.getDimSize(2);
3464 }
3465
3466 // Bias shape can describe the output channels.
3467 ShapeAdaptor biasShape(adaptor.getBias().getType());
3468 if (biasShape.hasRank()) {
3469 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3470 ? biasShape.getDimSize(0)
3471 : outputShape[3];
3472 }
3473
3474 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3475 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3476 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3477
3478 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3479 int64_t inputSize = inputHeight + padding[0] + padding[1];
3480 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3481 int64_t unstridedResult = inputSize - filterSize + 1;
3482 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3483 }
3484
3485 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3486 int64_t inputSize = inputWidth + padding[2] + padding[3];
3487 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3488 int64_t unstridedResult = inputSize - filterSize + 1;
3489 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3490 }
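  // For illustration (hypothetical values): inputHeight = 16, pad = [1, 1],
  // weightHeight = 3, dilation = 1 and stride = 2 give inputSize = 18,
  // filterSize = 3, unstridedResult = 16 and outputShape[1] = 15 / 2 + 1 = 8.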
3491
3492 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3493 return success();
3494}
3495
3496LogicalResult Conv2DOp::verify() {
3497 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3498 verifyConvOpErrorIf(*this).failed())
3499 return failure();
3500 return success();
3501}
3502
3503LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
3504 MLIRContext *context, ::std::optional<Location> location,
3505 Conv2DBlockScaledOp::Adaptor adaptor,
3506 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3507 SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
3508
3509 int64_t inputWidth = ShapedType::kDynamic;
3510 int64_t inputHeight = ShapedType::kDynamic;
3511 int64_t weightWidth = ShapedType::kDynamic;
3512 int64_t weightHeight = ShapedType::kDynamic;
3513
3514 // Input shape describes input width/height and batch.
3515 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3516 if (inputDataShape.hasRank()) {
3517 outShape[0] = inputDataShape.getDimSize(0);
3518 inputHeight = inputDataShape.getDimSize(1);
3519 inputWidth = inputDataShape.getDimSize(2);
3520 }
3521 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3522 if (inputScaleShape.hasRank()) {
3523 outShape[0] = ShapedType::isDynamic(outShape[0])
3524 ? inputScaleShape.getDimSize(0)
3525 : outShape[0];
3526 inputHeight = ShapedType::isDynamic(inputHeight)
3527 ? inputScaleShape.getDimSize(1)
3528 : inputHeight;
3529 inputWidth = ShapedType::isDynamic(inputWidth)
3530 ? inputScaleShape.getDimSize(2)
3531 : inputWidth;
3532 }
3533
3534 // Weight shape describes the filter width/height and the output channels.
3535 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3536 if (weightDataShape.hasRank()) {
3537 outShape[3] = weightDataShape.getDimSize(0);
3538 weightHeight = weightDataShape.getDimSize(1);
3539 weightWidth = weightDataShape.getDimSize(2);
3540 }
3541 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3542 if (weightScaleShape.hasRank()) {
3543 outShape[3] = ShapedType::isDynamic(outShape[3])
3544 ? weightScaleShape.getDimSize(0)
3545 : outShape[3];
3546 weightHeight = ShapedType::isDynamic(weightHeight)
3547 ? weightScaleShape.getDimSize(1)
3548 : weightHeight;
3549 weightWidth = ShapedType::isDynamic(weightWidth)
3550 ? weightScaleShape.getDimSize(2)
3551 : weightWidth;
3552 }
3553
3554 // Bias shape can describe the output channels.
3555 const ShapeAdaptor biasShape(adaptor.getBias().getType());
3556 if (biasShape.hasRank()) {
3557 const int64_t biasSize = biasShape.getDimSize(0);
3558 // Bias of size 1 may be broadcast
3559 if (biasSize != 1) {
3560 outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
3561 }
3562 }
3563
3564 SmallVector<int64_t> padValues;
3565 SmallVector<int64_t> strideValues;
3566 SmallVector<int64_t> dilationValues;
3567 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
3568 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
3569 strideValues) ||
3570 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
3571 dilationValues)) {
3572 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3573 return success();
3574 }
3575
3576 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3577 const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
3578 const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
3579 const int64_t unstridedResult = inputSize - filterSize + 1;
3580 outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
3581 }
3582
3583 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3584 const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
3585 const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
3586 const int64_t unstridedResult = inputSize - filterSize + 1;
3587 outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
3588 }
3589
3590 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3591 return success();
3592}
3593
3594LogicalResult Conv2DBlockScaledOp::verify() {
3595 if (failed(verifySameElementTypes(*this, getInputData().getType(),
3596 getWeightData().getType(), "input_data",
3597 "weight_data")) ||
3598 failed(verifySameElementTypes(*this, getInputScale().getType(),
3599 getWeightScale().getType(), "input_scale",
3600 "weight_scale")) ||
3601 failed(verifySameElementTypes(*this, getBias().getType(),
3602 getOutput().getType(), "bias", "output")))
3603 return failure();
3604
3605 // Verify input shape compatibility
3606 int64_t N = ShapedType::kDynamic;
3607 int64_t IH = ShapedType::kDynamic;
3608 int64_t IW = ShapedType::kDynamic;
3609 int64_t IC = ShapedType::kDynamic;
3610 int64_t multiplesOfIC = ShapedType::kDynamic;
3611 int64_t OC = ShapedType::kDynamic;
3612 int64_t KH = ShapedType::kDynamic;
3613 int64_t KW = ShapedType::kDynamic;
3614
3615 const ShapeAdaptor inputDataShape(getInputData().getType());
3616 if (inputDataShape.hasRank()) {
3617 N = inputDataShape.getDimSize(0);
3618 IH = inputDataShape.getDimSize(1);
3619 IW = inputDataShape.getDimSize(2);
3620 IC = inputDataShape.getDimSize(3);
3621 }
3622
3623 const ShapeAdaptor inputScaleShape(getInputScale().getType());
3624 if (inputScaleShape.hasRank()) {
3625 if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
3626 "input_scale", "batch size")) ||
3627 failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
3628 "input_scale", "input height")) ||
3629 failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
3630 "input_scale", "input width")))
3631 return failure();
3632 multiplesOfIC = inputScaleShape.getDimSize(3);
3633 }
3634
3635 const ShapeAdaptor weightDataShape(getWeightData().getType());
3636 if (weightDataShape.hasRank()) {
3637 OC = weightDataShape.getDimSize(0);
3638 KH = weightDataShape.getDimSize(1);
3639 KW = weightDataShape.getDimSize(2);
3640 if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
3641 "weight_data", "input channels")))
3642 return failure();
3643 }
3644
3645 const ShapeAdaptor weightScaleShape(getWeightScale().getType());
3646 if (weightScaleShape.hasRank()) {
3647 if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
3648 "weight_scale", "output channels")) ||
3649 failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
3650 "weight_scale", "kernel height")) ||
3651 failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
3652 "weight_scale", "kernel width")) ||
3653 failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
3654 weightScaleShape.getDimSize(3),
3655 "weight_scale", "input channel blocks")))
3656 return failure();
3657 }
3658
3659 // Verify IC is a multiple of block size
3660 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
3661 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
3662 return emitOpError("expect IC to be a multiple of block size, got IC=")
3663 << IC << ", block_size=" << blockSize;
3664
3665 // Verify multiplesOfIC is IC / block size
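3666 // (e.g., IC = 64 with block_size = 32 requires the last dimension of the
3667 // scale operands to be 64 / 32 = 2).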
3666 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
3667 multiplesOfIC != IC / blockSize)
3668 return emitOpError(
3669 "expect scale operands dimension 2 to equal IC/block_size (")
3670 << IC << "/" << blockSize << ")"
3671 << ", got " << multiplesOfIC;
3672
3673 // Verify pad/stride/dilation values
3674 SmallVector<int64_t> padValues;
3675 if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
3676 if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
3677 return emitOpError("expect all padding values to be >= 0, got ")
3678 << padValues;
3679 }
3680
3681 SmallVector<int64_t> strideValues;
3682 if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
3683 if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
3684 return emitOpError("expect all stride values to be >= 1, got ")
3685 << strideValues;
3686 }
3687
3688 SmallVector<int64_t> dilationValues;
3689 if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
3690 dilationValues)) {
3691 if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
3692 return emitOpError("expect all dilation values to be >= 1, got ")
3693 << dilationValues;
3694 }
3695
3696 // Verify output shape compatibility
3697 const ShapeAdaptor outputShape(getOutput().getType());
3698 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3699 outputShape.hasRank()) {
3700 if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
3701 padValues[0], padValues[1], strideValues[0],
3702 dilationValues[0], "height", "y", "top",
3703 "bottom")) ||
3704 failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
3705 padValues[2], padValues[3], strideValues[1],
3706 dilationValues[1], "width", "x", "left",
3707 "right")))
3708 return failure();
3709 }
3710
3711 // Verify bias
3712 const ShapeAdaptor biasShape(getBias().getType());
3713 if (biasShape.hasRank() && outputShape.hasRank()) {
3714 const int64_t biasChannels = biasShape.getDimSize(0);
3715 const int64_t outputChannels =
3716 outputShape.getDimSize(outputShape.getRank() - 1);
3717 if (biasChannels == ShapedType::kDynamic ||
3718 outputChannels == ShapedType::kDynamic)
3719 // Skip the following checks if biasChannels or outputChannels is dynamic
3720 return success();
3721
3722 if (biasChannels != outputChannels && biasChannels != 1)
3723 return emitOpError(
3724 "bias channels expected to be equal to output channels (")
3725 << outputChannels << ") or 1, got " << biasChannels;
3726 }
3727
3728 return success();
3729}
3730
3731LogicalResult Conv3DOp::inferReturnTypeComponents(
3732 MLIRContext *context, ::std::optional<Location> location,
3733 Conv3DOp::Adaptor adaptor,
3734 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3735 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3736
3737 int64_t inputWidth = ShapedType::kDynamic;
3738 int64_t inputHeight = ShapedType::kDynamic;
3739 int64_t inputDepth = ShapedType::kDynamic;
3740
3741 int64_t weightWidth = ShapedType::kDynamic;
3742 int64_t weightHeight = ShapedType::kDynamic;
3743 int64_t weightDepth = ShapedType::kDynamic;
3744
3745 // Input shape describes input depth/height/width and batch.
3746 ShapeAdaptor inputShape(adaptor.getInput().getType());
3747 if (inputShape.hasRank()) {
3748 outputShape[0] = inputShape.getDimSize(0);
3749 inputDepth = inputShape.getDimSize(1);
3750 inputHeight = inputShape.getDimSize(2);
3751 inputWidth = inputShape.getDimSize(3);
3752 }
3753
3754 // Weight shape describes the filter depth/height/width and the output channels.
3755 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3756 if (weightShape.hasRank()) {
3757 outputShape[4] = weightShape.getDimSize(0);
3758 weightDepth = weightShape.getDimSize(1);
3759 weightHeight = weightShape.getDimSize(2);
3760 weightWidth = weightShape.getDimSize(3);
3761 }
3762
3763 // Bias shape can describe the output channels.
3764 ShapeAdaptor biasShape(adaptor.getBias().getType());
3765 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3766 outputShape[4] = biasShape.getDimSize(0);
3767 }
3768
3769 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3770 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3771 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3772
3773 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3774 int32_t inputSize = inputDepth + pad[0] + pad[1];
3775 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3776 int32_t unstridedResult = inputSize - filterSize + 1;
3777 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3778 }
3779
3780 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3781 int32_t inputSize = inputHeight + pad[2] + pad[3];
3782 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3783 int32_t unstridedResult = inputSize - filterSize + 1;
3784 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3785 }
3786
3787 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3788 int32_t inputSize = inputWidth + pad[4] + pad[5];
3789 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3790 int32_t unstridedResult = inputSize - filterSize + 1;
3791 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3792 }
3793
3794 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3795 return success();
3796}
3797
3798LogicalResult Conv3DOp::verify() {
3799 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3800 verifyConvOpErrorIf(*this).failed())
3801 return failure();
3802 return success();
3803}
3804
3805LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3806 MLIRContext *context, ::std::optional<Location> location,
3807 AvgPool2dOp::Adaptor adaptor,
3808 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3809 ShapeAdaptor inputShape(adaptor.getInput().getType());
3810 const Properties &prop = adaptor.getProperties();
3811 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3812 inferredReturnShapes);
3813}
3814
3815LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3816 MLIRContext *context, ::std::optional<Location> location,
3817 MaxPool2dOp::Adaptor adaptor,
3818 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3819 ShapeAdaptor inputShape(adaptor.getInput().getType());
3820 const Properties &prop = adaptor.getProperties();
3821 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3822 inferredReturnShapes);
3823}
3824
3825LogicalResult MaxPool2dOp::verify() {
3826 if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3827 /* outType = */ getOutput().getType())))
3828 return failure();
3829
3830 if (failed(verifyPoolingOp(*this)))
3831 return failure();
3832
3833 return success();
3834}
3835
3836LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3837 MLIRContext *context, ::std::optional<Location> location,
3838 DepthwiseConv2DOp::Adaptor adaptor,
3839 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3840 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3841
3842 int64_t inputWidth = ShapedType::kDynamic;
3843 int64_t inputHeight = ShapedType::kDynamic;
3844 int64_t inputChannels = ShapedType::kDynamic;
3845
3846 int64_t weightWidth = ShapedType::kDynamic;
3847 int64_t weightHeight = ShapedType::kDynamic;
3848 int64_t depthChannels = ShapedType::kDynamic;
3849
3850 // Input shape describes input width/height and batch.
3851 ShapeAdaptor inputShape(adaptor.getInput().getType());
3852 if (inputShape.hasRank()) {
3853 outputShape[0] = inputShape.getDimSize(0);
3854 inputHeight = inputShape.getDimSize(1);
3855 inputWidth = inputShape.getDimSize(2);
3856 inputChannels = inputShape.getDimSize(3);
3857 }
3858
3859 // Weight shape describes the filter width/height and the output channels.
3860 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3861 if (weightShape.hasRank()) {
3862 weightHeight = weightShape.getDimSize(0);
3863 weightWidth = weightShape.getDimSize(1);
3864 inputChannels = ShapedType::isDynamic(inputChannels)
3865 ? weightShape.getDimSize(2)
3866 : inputChannels;
3867 depthChannels = weightShape.getDimSize(3);
3868 }
3869
3870 // If both inputChannels and depthChannels are available we can determine
3871 // the output channels.
3872 if (ShapedType::isStatic(inputChannels) &&
3873 ShapedType::isStatic(depthChannels)) {
3874 outputShape[3] = inputChannels * depthChannels;
3875 }
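  // For example, 8 input channels with a depth multiplier (depthChannels) of 2
  // give 16 output channels.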
3876
3877 // Bias shape can describe the output channels.
3878 ShapeAdaptor biasShape(adaptor.getBias().getType());
3879 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3880 int64_t bc = biasShape.getDimSize(0);
3881 if (bc != ShapedType::kDynamic && bc != 1)
3882 outputShape[3] = bc;
3883 }
3884
3885 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3886 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3887 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3888
3889 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3890 int64_t inputSize = inputHeight + padding[0] + padding[1];
3891 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3892 int64_t unstridedResult = inputSize - filterSize + 1;
3893 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3894 }
3895
3896 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3897 int64_t inputSize = inputWidth + padding[2] + padding[3];
3898 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3899 int64_t unstridedResult = inputSize - filterSize + 1;
3900 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3901 }
3902
3903 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3904 return success();
3905}
3906
3907LogicalResult DepthwiseConv2DOp::verify() {
3908 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3909 verifyConvOpErrorIf(*this).failed())
3910 return failure();
3911 return success();
3912}
3913
3914LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3915 MLIRContext *context, ::std::optional<Location> location,
3916 TransposeConv2DOp::Adaptor adaptor,
3917 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3918 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3919
3920 int64_t inputWidth = ShapedType::kDynamic;
3921 int64_t inputHeight = ShapedType::kDynamic;
3922 int64_t weightWidth = ShapedType::kDynamic;
3923 int64_t weightHeight = ShapedType::kDynamic;
3924
3925 // Input shape describes input width/height and batch.
3926 ShapeAdaptor inputShape(adaptor.getInput().getType());
3927 if (inputShape.hasRank()) {
3928 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3929 ? inputShape.getDimSize(0)
3930 : outputShape[0];
3931 inputHeight = inputShape.getDimSize(1);
3932 inputWidth = inputShape.getDimSize(2);
3933 }
3934
3935 // Weight shape describes the filter width/height and the output channels.
3936 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3937 if (weightShape.hasRank()) {
3938 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3939 ? weightShape.getDimSize(0)
3940 : outputShape[3];
3941 weightHeight = weightShape.getDimSize(1);
3942 weightWidth = weightShape.getDimSize(2);
3943 }
3944
3945 // Bias shape can describe the output channels.
3946 ShapeAdaptor biasShape(adaptor.getBias().getType());
3947 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3948 int64_t bc = biasShape.getDimSize(0);
3949 if (bc != ShapedType::kDynamic && bc != 1)
3950 outputShape[3] = bc;
3951 }
3952
3953 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3954 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3955
3956 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3957 int64_t calculateSize =
3958 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3959 outputShape[1] =
3960 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3961 }
3962
3963 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3964 int64_t calculateSize =
3965 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3966 outputShape[2] =
3967 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3968 }
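  // For illustration (hypothetical values): inputHeight = 8, stride = 2,
  // out_pad = [0, 0] and weightHeight = 3 give (8 - 1) * 2 + 0 + 0 + 3 = 17
  // output rows.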
3969
3970 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3971 return success();
3972}
3973
3974LogicalResult TransposeConv2DOp::verify() {
3975 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3976 return failure();
3977
3978 const llvm::ArrayRef<int64_t> strides = getStride();
3979 const int64_t strideY = strides[0];
3980 const int64_t strideX = strides[1];
3981
3982 if (strideY < 1 || strideX < 1)
3983 return emitOpError("expect all stride values to be >= 1, got [")
3984 << strides << "]";
3985
3986 const auto checkPadAgainstKernelDim =
3987 [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
3988 llvm::StringRef kernelDimName) -> LogicalResult {
3989 if (padValue <= -kernelDimSize)
3990 return emitOpError("expected ")
3991 << padName << " > -" << kernelDimName << ", but got: " << padName
3992 << "=" << padValue << " and " << kernelDimName << "="
3993 << kernelDimSize;
3994 return success();
3995 };
3996
3997 const llvm::ArrayRef<int64_t> padding = getOutPad();
3998 const int64_t outPadTop = padding[0];
3999 const int64_t outPadBottom = padding[1];
4000 const int64_t outPadLeft = padding[2];
4001 const int64_t outPadRight = padding[3];
4002
4003 const auto weightType =
4004 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
4005
4006 if (weightType) {
4007 const int64_t kernelHeight = weightType.getDimSize(1);
4008 if (ShapedType::isStatic(kernelHeight)) {
4009 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4010 "out_pad_top", "KH")))
4011 return failure();
4012
4013 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4014 "out_pad_bottom", "KH")))
4015 return failure();
4016 }
4017
4018 const int64_t kernelWidth = weightType.getDimSize(2);
4019 if (ShapedType::isStatic(kernelWidth)) {
4020 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4021 "out_pad_left", "KW")))
4022 return failure();
4023
4024 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4025 "out_pad_right", "KW")))
4026 return failure();
4027 }
4028 }
4029
4030 // Rest of the checks depend on the output type being a RankedTensorType
4031 const auto outputType =
4032 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
4033 if (!outputType)
4034 return success();
4035
4036 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
4037 if (inputType && weightType) {
4038 const int64_t inputHeight = inputType.getDimSize(1);
4039 const int64_t kernelHeight = weightType.getDimSize(1);
4040 const int64_t outputHeight = outputType.getDimSize(1);
4041
4042 if (ShapedType::isStatic(inputHeight) &&
4043 ShapedType::isStatic(outputHeight)) {
4044 if (outputHeight !=
4045 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4046 return emitOpError(
4047 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4048 "+ out_pad_top + out_pad_bottom + KH, but got ")
4049 << outputHeight << " != (" << inputHeight << " - 1) * "
4050 << strideY << " + " << outPadTop << " + " << outPadBottom
4051 << " + " << kernelHeight;
4052 }
4053
4054 const int64_t inputWidth = inputType.getDimSize(2);
4055 const int64_t kernelWidth = weightType.getDimSize(2);
4056 const int64_t outputWidth = outputType.getDimSize(2);
4057
4058 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4059 if (outputWidth !=
4060 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4061 return emitOpError(
4062 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4063 "+ out_pad_left + out_pad_right + KW, but got ")
4064 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
4065 << " + " << outPadLeft << " + " << outPadRight << " + "
4066 << kernelWidth;
4067 }
4068 }
4069
4070 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
4071
4072 if (!biasType)
4073 return success();
4074
4075 const int64_t biasChannels = biasType.getDimSize(0);
4076
4077 // Skip further checks if bias is dynamic
4078 if (biasChannels == ShapedType::kDynamic)
4079 return success();
4080
4081 const int64_t outputChannels = outputType.getDimSize(3);
4082 if (!ShapedType::isDynamic(outputChannels) &&
4083 biasChannels != outputChannels && biasChannels != 1)
4084 return emitOpError(
4085 "bias channels expected to be equal to output channels (")
4086 << outputChannels << ") or 1, got " << biasChannels;
4087
4088 return success();
4089}
4090
4091LogicalResult RescaleOp::verify() {
4092 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
4093 if (!inputType) {
4094 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
4095 return failure();
4096 }
4097
4098 auto inputElementType =
4099 getStorageElementTypeOrSelf(inputType.getElementType());
4100 if (!mlir::isa<IntegerType>(inputElementType)) {
4101 emitOpError("expect input to have integer element type, got ")
4102 << inputElementType;
4103 return failure();
4104 }
4105
4106 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
4107 if (!outputType) {
4108 emitOpError("expect shaped tensor for output, got ")
4109 << getOutput().getType();
4110 return failure();
4111 }
4112
4113 auto outputElementType =
4114 getStorageElementTypeOrSelf(outputType.getElementType());
4115 if (!mlir::isa<IntegerType>(outputElementType)) {
4116 emitOpError("expect output to have integer element type, got ")
4117 << outputElementType;
4118 return failure();
4119 }
4120
4121 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4122 .failed())
4123 return failure();
4124
4125 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4126 .failed())
4127 return failure();
4128
4129 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4130 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4131 return failure();
4132
4133 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4134 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4135 return failure();
4136
4137 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
4138 if (!multiplierType) {
4139 emitOpError("expect shaped tensor for multiplier, got ")
4140 << getMultiplier().getType();
4141 return failure();
4142 }
4143
4144 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
4145 if (!shiftType) {
4146 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
4147 return failure();
4148 }
4149
4150 // multiplier element type must be i32 for scale32 = true
4151 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4152 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4153 << multiplierType.getElementType();
4154 return failure();
4155 }
4156
4157 // multiplier element type must be i16 for scale32 = false
4158 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4160 "expect i16 element type for multiplier for scale32=false, got ")
4161 << multiplierType.getElementType();
4162 return failure();
4163 }
4164
4165 if (!inputType.hasRank())
4166 return success();
4167
4168 // multiplier/shift must have shape = {numChannels}, where numChannels is 1
4169 // if per_channel = false, otherwise numChannels is the size of the input
4170 // shape's last dimension.
4171 int64_t numChannels = 1;
4172 if (getPerChannel()) {
4173 if (inputType.getRank() < 1) {
4174 emitOpError("requires input to be at least rank 1 when per_channel is "
4175 "true, but got rank ")
4176 << inputType.getRank();
4177 return failure();
4178 }
4179 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4180 }
4181
4182 if (!multiplierType.hasRank())
4183 return success();
4184
4185 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4186 // multiplier input has rank 1 by dialect definition
4187 if (multiplierShape[0] != ShapedType::kDynamic &&
4188 multiplierShape[0] != numChannels) {
4189 emitOpError("expect shape of { ")
4190 << numChannels << " } for multiplier input, got { "
4191 << multiplierShape[0] << " }";
4192 return failure();
4193 }
4194
4195 if (!shiftType.hasRank())
4196 return success();
4197
4198 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4199 // shift input has rank 1 by dialect definition
4200 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4201 emitOpError("expect shape of { ")
4202 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
4203 return failure();
4204 }
4205
4206 return success();
4207}
4208
4209LogicalResult RescaleOp::inferReturnTypeComponents(
4210 MLIRContext *context, ::std::optional<Location> location,
4211 RescaleOp::Adaptor adaptor,
4212 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4213 ShapeAdaptor inputShape(adaptor.getInput().getType());
4214 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4215 return success();
4216}
4217
4218LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4219 MLIRContext *context, ::std::optional<Location> location,
4220 CastFromBlockScaledOp::Adaptor adaptor,
4221 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4222 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4223 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4224 return success();
4225}
4226
4227LogicalResult CastFromBlockScaledOp::verify() {
4228 const Type inputDataType = getInputData().getType();
4229 const Type outputDataType = getResult().getType();
4230 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4231 return emitOpError() << "require compatible shapes for input_data ("
4232 << inputDataType << ") and " << "output_data ("
4233 << outputDataType << ")";
4234
4235 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4236
4237 if (inputDataShape.hasRank()) {
4238 const unsigned int blockSize =
4239 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4240 const int64_t inputDataLastDim =
4241 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4242 if (inputDataLastDim % blockSize != 0)
4243 return emitOpError() << "expect last dimension of input_data ("
4244 << inputDataLastDim
4245 << ") to be divisible by block_size (" << blockSize
4246 << ")";
4247
4248 const Type inputScaleType = getInputScale().getType();
4249 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4250
4251 if (inputScaleShape.hasRank()) {
4252 SmallVector<int64_t> inputDataDims, inputScaleDims;
4253 inputDataShape.getDims(inputDataDims);
4254 inputScaleShape.getDims(inputScaleDims);
4255
4256 if (inputDataDims.size() != inputScaleDims.size() ||
4257 failed(verifyCompatibleShape(
4258 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4259 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4260 return emitOpError()
4261 << "require compatible shapes for input_data (" << inputDataType
4262 << ") and " << "input_scale (" << inputScaleType
4263 << ") except for the last dimension";
4264
4265 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4266 inputScaleDims.back()};
4267 if (ShapedType::isStatic(inputDataLastDim) &&
4268 failed(verifyCompatibleDims(dimsToCheck)))
4269 return emitOpError()
4270 << "expect last dimension of input_scale ("
4271 << inputScaleDims.back()
4272 << ") to be equal to last dimension of input_data / block_size ("
4273 << inputDataDims.back() / blockSize << ")";
4274 }
4275 }
4276
4277 return success();
4278}
4279
4280LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4281 MLIRContext *context, ::std::optional<Location> location,
4282 CastToBlockScaledOp::Adaptor adaptor,
4283 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4284 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4285 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4286 if (!inputShape.hasRank())
4287 return success();
4288
4289 // Calculate output_scale shape if ranked input provided
4290 SmallVector<int64_t> outputScaleShape;
4291 inputShape.getDims(outputScaleShape);
4292 const int64_t lastDimLoc = inputShape.getRank() - 1;
4293 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4294 if (ShapedType::isStatic(lastDimSize)) {
4295 const unsigned int blockSize =
4296 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4297 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4298 }
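  // For example, an input_data shape of [8, 64] with block_size = 32 infers
  // an output_scale shape of [8, 2].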
4299 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4300 return success();
4301}
4302
4303LogicalResult CastToBlockScaledOp::verify() {
4304 const Type inputDataType = getInputData().getType();
4305 const Type outputDataType = getResult(0).getType();
4306 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4307 return emitOpError() << "require compatible shapes for input_data ("
4308 << inputDataType << ") and " << "output_data ("
4309 << outputDataType << ")";
4310
4311 const unsigned int blockSize =
4312 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4313 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4314 if (inputDataShape.hasRank()) {
4315 const int64_t inputDataLastDim =
4316 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4317 if (ShapedType::isStatic(inputDataLastDim) &&
4318 inputDataLastDim % blockSize != 0)
4319 return emitOpError() << "expect last dimension of input_data ("
4320 << inputDataLastDim
4321 << ") to be divisible by block_size (" << blockSize
4322 << ")";
4323 }
4324
4325 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4326 const Type outputScaleType = getResult(1).getType();
4327 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4328 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4329 SmallVector<int64_t> outputDataDims, outputScaleDims;
4330 outputDataShape.getDims(outputDataDims);
4331 outputScaleShape.getDims(outputScaleDims);
4332
4333 if (outputDataDims.size() != outputScaleDims.size() ||
4334 failed(verifyCompatibleShape(
4335 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4336 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4337 return emitOpError() << "require compatible shapes for output_data ("
4338 << outputDataType << ") and " << "output_scale ("
4339 << outputScaleType
4340 << ") except for the last dimension";
4341
4342 const int64_t outputDataLastDim = outputDataDims.back();
4343 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4344 outputScaleDims.back()};
4345 if (ShapedType::isStatic(outputDataLastDim) &&
4346 failed(verifyCompatibleDims(dimsToCheck)))
4347 return emitOpError()
4348 << "expect last dimension of output_scale ("
4349 << outputScaleDims.back()
4350 << ") to be equal to last dimension of output_data / block_size ("
4351 << outputDataDims.back() / blockSize << ")";
4352 }
4353
4354 return success();
4355}
4356
4357LogicalResult IfOp::inferReturnTypeComponents(
4358 MLIRContext *context, ::std::optional<Location> location,
4359 IfOp::Adaptor adaptor,
4360 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4361 llvm::SmallVector<tosa::YieldOp> yieldOps;
4362 for (Region *region : adaptor.getRegions()) {
4363 for (auto &block : *region)
4364 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4365 yieldOps.push_back(returnOp);
4366 }
4367
4368 if (yieldOps.empty())
4369 return failure();
4370
4371 // Get the initial type information for the yield op.
4372 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4373 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4374 for (auto operand : yieldOps.front().getOperands()) {
4375 resultKnowledge.push_back(
4376 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4377 }
4378
4379 for (auto yieldOp : yieldOps) {
4380 if (resultKnowledge.size() != yieldOp.getNumOperands())
4381 return failure();
4382
4383 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4384 int32_t index = it.index();
4385 auto meet = ValueKnowledge::meet(
4386 resultKnowledge[index],
4387 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4388 if (!meet)
4389 continue;
4390 resultKnowledge[index] = meet;
4391 }
4392 }
4393
4394 for (const ValueKnowledge &result : resultKnowledge) {
4395 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4396 }
4397
4398 return success();
4399}
4400
4401LogicalResult WhileOp::inferReturnTypeComponents(
4402 MLIRContext *context, ::std::optional<Location> location,
4403 WhileOp::Adaptor adaptor,
4404 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4405 llvm::SmallVector<tosa::YieldOp> yieldOps;
4406 for (auto &block : adaptor.getBodyGraph())
4407 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4408 yieldOps.push_back(returnOp);
4409
4410 // TOSA's while must have a tosa.yield as its terminator. If not found, this
4411 // tosa.while is invalid.
4412 if (yieldOps.empty())
4413 return failure();
4414
4415 // Get the initial type information from the operand types.
4416 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4417 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4418 for (auto operand : yieldOps.front().getOperands()) {
4419 resultKnowledge.push_back(
4420 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4421 }
4422
4423 for (auto yieldOp : yieldOps) {
4424 if (resultKnowledge.size() != yieldOp.getNumOperands())
4425 return failure();
4426
4427 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4428 int32_t index = it.index();
4429 if (auto meet = ValueKnowledge::meet(
4430 resultKnowledge[index],
4431 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
4432 resultKnowledge[index] = meet;
4433 }
4434 }
4435 }
4436
4437 for (const ValueKnowledge &result : resultKnowledge) {
4438 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4439 }
4440
4441 return success();
4442}
4443
4444std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4445 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4446 return llvm::to_vector<4>(vt.getShape());
4447 return std::nullopt;
4448}
4449
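  // Prints the optional "(arg = init, ...)" assignment list, e.g. (with
  // illustrative SSA names) " (%arg0 = %0, %arg1 = %1)" when two block
  // arguments are initialized.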
4450 static void printInitializationList(OpAsmPrinter &parser,
4451 Block::BlockArgListType blocksArgs,
4452 ValueRange initializers,
4453 StringRef prefix = "") {
4454 assert(blocksArgs.size() == initializers.size() &&
4455 "expected same length of arguments and initializers");
4456 if (initializers.empty())
4457 return;
4458
4459 parser << prefix << '(';
4460 llvm::interleaveComma(
4461 llvm::zip(blocksArgs, initializers), parser,
4462 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4463 parser << ")";
4464}
4465
4466 // The parse and print implementations of IfOp follow those of the SCF dialect.
4467ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
4468 // Create the regions for 'then'.
4469 result.regions.reserve(2);
4470 Region *thenRegion = result.addRegion();
4471 Region *elseRegion = result.addRegion();
4472
4473 OpAsmParser::UnresolvedOperand cond;
4474
4475 if (parser.parseOperand(cond))
4476 return failure();
4477
4478 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4479 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4480
4481 // Parse the optional block arguments
4482 OptionalParseResult listResult =
4483 parser.parseOptionalAssignmentList(regionArgs, operands);
4484 if (listResult.has_value() && failed(listResult.value()))
4485 return failure();
4486
4487 // Parse a colon.
4488 if (failed(parser.parseColon()))
4489 return parser.emitError(parser.getCurrentLocation(),
4490 "expected type for condition operand");
4491
4492 // Parse the type of the condition operand
4493 Type condType;
4494 if (failed(parser.parseType(condType)))
4495 return parser.emitError(parser.getCurrentLocation(),
4496 "expected type for condition operand");
4497
4498 // Resolve operand with provided type
4499 if (failed(parser.resolveOperand(cond, condType, result.operands)))
4500 return failure();
4501
4502 // Parse optional block arg types
4503 if (listResult.has_value()) {
4504 FunctionType functionType;
4505
4506 if (failed(parser.parseType(functionType)))
4507 return parser.emitError(parser.getCurrentLocation())
4508 << "expected list of types for block arguments "
4509 << "followed by arrow type and list of return types";
4510
4511 result.addTypes(functionType.getResults());
4512
4513 if (functionType.getNumInputs() != operands.size()) {
4514 return parser.emitError(parser.getCurrentLocation())
4515 << "expected as many input types as operands " << "(expected "
4516 << operands.size() << " got " << functionType.getNumInputs()
4517 << ")";
4518 }
4519
4520 // Resolve input operands.
4521 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4522 parser.getCurrentLocation(),
4523 result.operands)))
4524 return failure();
4525 } else {
4526 // Parse optional results type list.
4527 if (parser.parseOptionalArrowTypeList(result.types))
4528 return failure();
4529 }
4530
4531 // Parse the 'then' region.
4532 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
4533 return failure();
4534
4535 // If we find an 'else' keyword then parse the 'else' region.
4536 if (!parser.parseOptionalKeyword("else")) {
4537 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
4538 return failure();
4539 }
4540
4541 // Parse the optional attribute list.
4542 if (parser.parseOptionalAttrDict(result.attributes))
4543 return failure();
4544 return success();
4545}
4546
4547void IfOp::print(OpAsmPrinter &p) {
4548 p << " " << getCondition();
4549
4550 printInitializationList(p, getThenGraph().front().getArguments(),
4551 getInputList(), " ");
4552 p << " : ";
4553 p << getCondition().getType();
4554
4555 if (!getInputList().empty()) {
4556 p << " (";
4557 llvm::interleaveComma(getInputList().getTypes(), p);
4558 p << ")";
4559 }
4560 p.printArrowTypeList(getResultTypes());
4561 p << " ";
4562
4563 p.printRegion(getThenGraph());
4564
4565 // Print the 'else' region if it exists and has a block.
4566 auto &elseRegion = getElseGraph();
4567 if (!elseRegion.empty()) {
4568 p << " else ";
4569 p.printRegion(elseRegion);
4570 }
4571
4572 p.printOptionalAttrDict((*this)->getAttrs());
4573}
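Read together, the parse and print methods above imply a custom assembly roughly of the following shape. This is a hand-written sketch; the value names and tensor types are illustrative and not taken from the source:

// %r = tosa.cond_if %cond (%arg0 = %x) : tensor<i1> (tensor<f32>) -> tensor<f32> {
//   ...
//   tosa.yield %then_value : tensor<f32>
// } else {
//   ...
//   tosa.yield %else_value : tensor<f32>
// }
//
// When the input list is empty, the parenthesized assignment list and its type
// list are omitted, leaving only the condition, its type, and the arrow type list.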
4574
4575LogicalResult IfOp::verify() {
4576 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4577 "'then_graph' arguments", getInputList(),
4578 "'input_list'")
4579 .failed())
4580 return failure();
4581
4582 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4583 "'else_graph' arguments", getInputList(),
4584 "'input_list'")
4585 .failed())
4586 return failure();
4587
4588 // If the terminator is absent, MLIR's standard verification reports it for us.
4589 if (getThenGraph().front().mightHaveTerminator()) {
4590 auto thenYield =
4591 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4592 if (thenYield && errorIfTypeOrShapeMismatch(
4593 *this, thenYield.getInputs(), "'then_graph' results",
4594 getOutputList(), "'output_list'")
4595 .failed())
4596 return failure();
4597 }
4598
4599 // If the terminator is absent, MLIR's standard verification reports it for us.
4600 if (getElseGraph().front().mightHaveTerminator()) {
4601 auto elseYield =
4602 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4603 if (elseYield && errorIfTypeOrShapeMismatch(
4604 *this, elseYield.getInputs(), "'else_graph' results",
4605 getOutputList(), "'output_list'")
4606 .failed())
4607 return failure();
4608 }
4609
4610 auto condType = getCondition().getType();
4611 if (errorIfShapeNotSizeOne(*this, condType).failed())
4612 return emitOpError() << "'condition' must be a size 1 tensor, got "
4613 << condType;
4614
4615 return success();
4616}
4617
4618LogicalResult WhileOp::verify() {
4619 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4620 getOutputList(), "'output_list'")
4621 .failed())
4622 return failure();
4623
4624 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4625 "'cond_graph' arguments", getInputList(),
4626 "'input_list'")
4627 .failed())
4628 return failure();
4629
4630 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4631 "'body_graph' arguments", getInputList(),
4632 "'input_list'")
4633 .failed())
4634 return failure();
4635
4636 if (getBodyGraph().front().mightHaveTerminator()) {
4637 auto bodyYield =
4638 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4639 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4640 "'body_graph' results",
4641 getInputList(), "'input_list'")
4642 .failed())
4643 return failure();
4644 }
4645
4646 // Condition block output must be a single element tensor with a single bool
4647 // value.
4648 if (!getCondGraph().front().mightHaveTerminator())
4649 return success();
4650
4651 auto condYield =
4652 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4653 if (!condYield)
4654 return success();
4655
4656 if (condYield.getInputs().size() != 1)
4657 return emitOpError() << "require 'cond_graph' only have one result";
4658
4659 auto condOutType = condYield.getInputs()[0].getType();
4660 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4661 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4662 << condOutType;
4663
4664 if (!getElementTypeOrSelf(condOutType).isInteger(1))
4665 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4666 << condOutType;
4667
4668 return success();
4669}
4670
4671LogicalResult ReverseOp::verify() {
4672 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4673 /* outType = */ getOutput().getType())
4674 .failed())
4675 return failure();
4676 TensorType inputType = getInput1().getType();
4677 TensorType outputType = getOutput().getType();
4678 int32_t reverseAxis = getAxis();
4679
4680 if (reverseAxis < 0)
4681 return emitOpError("expected non-negative reverse axis");
4682 if (inputType.hasRank()) {
4683 int64_t inputRank = inputType.getRank();
4684 // We allow for a special case where the input/output shape has rank 0 and
4685 // axis is also 0.
4686 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4687 return emitOpError("expect input tensor rank (")
4688 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4689 << ")";
4690 }
4691 if (outputType.hasRank()) {
4692 int64_t outputRank = outputType.getRank();
4693 if (inputType.hasRank() && outputRank != inputType.getRank())
4694 return emitOpError(
4695 "expect output tensor rank to be equal to input tensor rank");
4696 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4697 return emitOpError("expect output tensor rank (")
4698 << outputRank << ") to be larger than reverse axis ("
4699 << reverseAxis << ")";
4700 }
4701 return success();
4702}
4703
4704LogicalResult tosa::SelectOp::verify() {
4705 // verify input2 and input3 have same element type as output
4706 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4707 /* outType = */ getOutput().getType())
4708 .failed() ||
4709 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4710 /* outType = */ getOutput().getType())
4711 .failed()) {
4712 return failure();
4713 }
4714 // verify input1 has element type of bool
4715 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4716 if (!predicateType) {
4717 return emitOpError("expect shaped tensor for input1, got ")
4718 << getInput1().getType();
4719 }
4720 auto predicateElementType = predicateType.getElementType();
4721 if (!predicateElementType.isInteger(1)) {
4722 return emitOpError("expect element type of bool for input1, got ")
4723 << predicateElementType;
4724 }
4725
4726 return success();
4727}
4728
4729LogicalResult tosa::VariableReadOp::verify() {
4730 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4731 .failed())
4732 return failure();
4733
4734 return success();
4735}
4736
4737LogicalResult tosa::VariableWriteOp::verify() {
4738 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4739 .failed())
4740 return failure();
4741
4742 return success();
4743}
4744
4745// The parse and print of WhileOp follow the implementation of the SCF dialect.
4746ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4747 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4748 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4749 Region *cond = result.addRegion();
4750 Region *body = result.addRegion();
4751
4752 OptionalParseResult listResult =
4753 parser.parseOptionalAssignmentList(regionArgs, operands);
4754 if (listResult.has_value() && failed(listResult.value()))
4755 return failure();
4756
4757 FunctionType functionType;
4758 SMLoc typeLoc = parser.getCurrentLocation();
4759 if (failed(parser.parseColonType(functionType)))
4760 return failure();
4761
4762 result.addTypes(functionType.getResults());
4763
4764 if (functionType.getNumInputs() != operands.size()) {
4765 return parser.emitError(typeLoc)
4766 << "expected as many input types as operands " << "(expected "
4767 << operands.size() << " got " << functionType.getNumInputs() << ")";
4768 }
4769
4770 // Resolve input operands.
4771 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4772 parser.getCurrentLocation(),
4773 result.operands)))
4774 return failure();
4775
4776 // Propagate the types into the region arguments.
4777 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4778 regionArgs[i].type = functionType.getInput(i);
4779
4780 return failure(parser.parseRegion(*cond, regionArgs) ||
4781 parser.parseKeyword("do") || parser.parseRegion(*body) ||
4782 parser.parseOptionalAttrDictWithKeyword(result.attributes));
4783}
4784
4785void WhileOp::print(OpAsmPrinter &parser) {
4786 printInitializationList(parser, getCondGraph().front().getArguments(),
4787 getInputList(), " ");
4788 parser << " : ";
4789 parser.printFunctionalType(getInputList().getTypes(),
4790 getResults().getTypes());
4791 parser << ' ';
4792 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4793 parser << " do ";
4794 parser.printRegion(getBodyGraph());
4795 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4796}
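As with IfOp above, these methods imply a textual form along the following lines; a sketch only, with illustrative names and types. Note that the cond region is printed without its entry block arguments, while the body region keeps them:

// %r = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//   %keep_going = ...
//   tosa.yield %keep_going : tensor<i1>
// } do {
// ^bb0(%arg1: tensor<i32>):
//   ...
//   tosa.yield %next : tensor<i32>
// }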
4797
4798// Create a rank-1 const tensor holding the zero point of the source tensor.
4799std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4800 Location loc,
4801 Type srcElemType,
4802 int64_t zp) {
4803 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4804 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4805 if (llvm::isa<FloatType>(srcElemType)) {
4806 auto zpAttr = DenseElementsAttr::get(
4807 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4808 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4809 }
4810 if (llvm::isa<IntegerType>(srcElemType)) {
4811 auto zpAttr =
4812 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4813 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4814 }
4815 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4816 return std::nullopt;
4817}
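A minimal call-site sketch for this helper, assuming an OpBuilder `builder` and a Location `loc` are already in scope (both names are placeholders, not part of this file):

//   // Build a rank-1 i8 zero-point tensor holding -128.
//   std::optional<Value> zp =
//       tosa::createZeroPointTensor(builder, loc, builder.getI8Type(), /*zp=*/-128);
//   // The helper returns std::nullopt for unsupported element types, so check it.
//   if (!zp)
//     return failure();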
4818
4819//===----------------------------------------------------------------------===//
4820// TOSA Shape and Shape Operators Helper functions.
4821//===----------------------------------------------------------------------===//
4822
4823bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4824 return mlir::isa<tosa::shapeType>(t);
4825}
4826
4827LogicalResult
4828mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4829 int rank) {
4830 if (rank < 0)
4831 return emitError() << "invalid rank (must be >= 0): " << rank;
4832 return success();
4833}
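For orientation, a few illustrative spellings that this verifier constrains (not exhaustive):

//   !tosa.shape<3>   shape value of rank 3: accepted (rank >= 0)
//   !tosa.shape<0>   shape value of rank 0: accepted
//   !tosa.shape<-1>  rejected with "invalid rank (must be >= 0): -1"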
4834
4835LogicalResult mlir::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4836 for (auto v : op->getOperands()) {
4837 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4838 Operation *definingOp = v.getDefiningOp();
4839 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4840 return op->emitOpError("shape operand is not compile time resolvable");
4841 }
4842 }
4843 }
4844 return success();
4845}
4846
4847LogicalResult
4848mlir::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4849 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
4850 return failure();
4851
4852 // delegate function that returns rank of shape type
4853 auto getRank = [](const Type type) {
4854 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4855 };
4856 auto operandTypes = op->getOperandTypes();
4857 auto resultTypes = op->getResultTypes();
4858
4859 auto rank = getRank(*op->getOperandTypes().begin());
4860 for (auto type : operandTypes) {
4861 if (getRank(type) != rank) {
4862 return op->emitOpError("operands don't have matching ranks");
4863 }
4864 }
4865 for (auto type : resultTypes) {
4866 if (getRank(type) != rank) {
4867 return op->emitOpError("result shape has different rank than operands");
4868 }
4869 }
4870 return success();
4871}
4872
4873//===----------------------------------------------------------------------===//
4874// TOSA Shape Operators verify functions.
4875//===----------------------------------------------------------------------===//
4876
4877LogicalResult tosa::ConstShapeOp::verify() {
4878 // Check that the values attribute has rank 1.
4879 auto valuesRank = getValues().getType().getRank();
4880 if (valuesRank != 1)
4881 return emitOpError("expect elements in attribute values with rank 1");
4882 // Check that the number of elements in the values attribute equals the rank of the result shape.
4883 auto count = getValues().getNumElements();
4884 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4885 if (count != rank && (count != 1 || rank != 0)) {
4886 return emitOpError("expect number of elements in attribute values (")
4887 << count << ") to be equal to the rank (" << rank
4888 << ") for the result shape type";
4889 }
4890 return success();
4891}
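To make the count/rank rule concrete, a sketch of combinations the check above accepts and rejects (attribute and type spellings abbreviated):

//   3-element rank-1 values attribute, result !tosa.shape<3>  -> accepted (count == rank)
//   1-element rank-1 values attribute, result !tosa.shape<0>  -> accepted (special case: count 1, rank 0)
//   2-element rank-1 values attribute, result !tosa.shape<3>  -> rejected by the check above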
4892
4893LogicalResult tosa::DimOp::verify() {
4894 const tosa::shapeType outShapeType =
4895 cast<tosa::shapeType>(getResult().getType());
4896 if (outShapeType.getRank() != 1)
4897 return emitOpError("expect output shape type to contain one element, got ")
4898 << outShapeType;
4899
4900 const ShapeAdaptor inputType(getInput1().getType());
4901 if (inputType.hasRank()) {
4902 const int64_t inputRank = inputType.getRank();
4903 const int64_t axis = getAxisAttr().getInt();
4904 if (axis < 0 || axis >= inputRank)
4905 return emitOpError("expect axis to be in the range [0, ")
4906 << inputRank << "), got " << axis;
4907 }
4908 return success();
4909}
4910
4911LogicalResult tosa::ConcatShapeOp::verify() {
4912 const tosa::shapeType outShapeType =
4913 cast<tosa::shapeType>(getResult().getType());
4914 const int64_t outputRank = outShapeType.getRank();
4915 const Operation::operand_range inputList = getInput();
4916 const int64_t inputsRank =
4917 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
4918 const tosa::shapeType inShapeType =
4919 cast<tosa::shapeType>(input.getType());
4920 return acc + inShapeType.getRank();
4921 });
4922 if (outputRank != inputsRank)
4923 return emitOpError("requires output shape rank to be equal to the sum of "
4924 "the input shape ranks (")
4925 << inputsRank << "), got " << outputRank;
4926
4927 return success();
4928}
4929
4930LogicalResult tosa::SliceShapeOp::verify() {
4931 std::optional<int32_t> start;
4932 DenseIntElementsAttr startAttr;
4933 if (matchPattern(getStart(), m_Constant(&startAttr)))
4934 start = startAttr.getValues<int32_t>()[0];
4935 if (start && start.value() < 0)
4936 return emitOpError("expected non-negative start index, got ")
4937 << start.value();
4938
4939 std::optional<int32_t> size;
4940 DenseIntElementsAttr sizeAttr;
4941 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
4942 size = sizeAttr.getValues<int32_t>()[0];
4943 if (size && size.value() <= 0)
4944 return emitOpError("expected positive size, got ") << size.value();
4945
4946 if (!size)
4947 return success();
4948
4949 const tosa::shapeType outShapeType =
4950 cast<tosa::shapeType>(getResult().getType());
4951 const int64_t outputRank = outShapeType.getRank();
4952 if (outputRank != size)
4953 return emitOpError(
4954 "expected output type size to be equal to size attribute, got ")
4955 << outputRank << " vs " << size.value();
4956
4957 if (!start)
4958 return success();
4959
4960 const tosa::shapeType inShapeType =
4961 cast<tosa::shapeType>(getInput().getType());
4962 const int64_t inputRank = inShapeType.getRank();
4963 const int64_t sliceSize = start.value() + size.value();
4964 if (sliceSize > inputRank)
4965 return emitOpError("expected start + size to be less than or equal to "
4966 "input shape rank (")
4967 << inputRank << "), got " << sliceSize;
4968
4969 return success();
4970}
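A small worked example of the start/size arithmetic checked above (illustrative values):

//   input !tosa.shape<4>, start = 1, size = 3  -> 1 + 3 = 4 <= 4, accepted
//   input !tosa.shape<4>, start = 2, size = 3  -> 2 + 3 = 5 >  4, rejected
//   (size must also be positive, start non-negative, and the output rank
//    must equal size, per the earlier checks)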
4971
4972//===----------------------------------------------------------------------===//
4973// TOSA Attribute Definitions.
4974//===----------------------------------------------------------------------===//
4975
4976#define GET_ATTRDEF_CLASSES
4977#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4978
4979//===----------------------------------------------------------------------===//
4980// TOSA Type Definitions.
4981//===----------------------------------------------------------------------===//
4982#define GET_TYPEDEF_CLASSES
4983#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4984
4985//===----------------------------------------------------------------------===//
4986// TOSA Operator Definitions.
4987//===----------------------------------------------------------------------===//
4988
4989#define GET_OP_CLASSES
4990#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"